update 20250403
This commit is contained in:
@@ -19,3 +19,9 @@ api_chat/runtime/*
|
||||
api_chat/docs/*
|
||||
!api_chat/docs/.gitkeep
|
||||
|
||||
api_chat/GPT_weights
|
||||
api_chat/GPT_weights_v2
|
||||
api_chat/SoVITS_weights
|
||||
api_chat/SoVITS_weights_v2
|
||||
|
||||
api/logs
|
||||
@@ -7,7 +7,6 @@
|
||||
```
|
||||
API/
|
||||
├── api/ # 视觉分析和处理模块
|
||||
│ ├── producer/ # 生产者,分配任务
|
||||
│ ├── cpm_analyze.py # CPM_OCR分析
|
||||
│ ├── qwenvl_analyze.py # QwenVL_OCR分析
|
||||
│ ├── cpm_scene.py # CPM_场景分析
|
||||
@@ -17,22 +16,23 @@
|
||||
│ ├── face.py # 人脸检测
|
||||
│ ├── fall.py # 跌倒检测
|
||||
│ ├── pose.py # 姿态估计
|
||||
│ └── media.py # mediapipe 面部特征提取
|
||||
│ ├── media.py # mediapipe 面部特征提取
|
||||
| ├── start_services.sh #一键开始所有程序
|
||||
| └── stop.sh #一键停止所有程序
|
||||
├── api_chat/ # 聊天和语音处理模块
|
||||
│ ├── producer_chat/ # 聊天生产者
|
||||
│ ├── chat.py # 聊天功能
|
||||
│ ├── tts.py # 文字转语音
|
||||
│ ├── asr.py # 语音识别
|
||||
│ ├── chat.py # 聊天功能
|
||||
│ ├── tts.py # 文字转语音
|
||||
│ ├── asr.py # 语音识别
|
||||
│ ├── GPT_SoVITS/ # GPT_SoVITS模型集成,
|
||||
│ ├── sample/ # OpenBMB模型——学习音色,音色+文本内容
|
||||
│ ├── tools/ # GPT_SoVITS模型——工具函数
|
||||
│ ├── runtime/ # GPT_SoVITS模型——运行时函数
|
||||
│ ├── docs/ # GPT_SoVITS模型——文档
|
||||
│ ├── TEMP/ # OpenBMB模型临时文件夹
|
||||
│ └── before/ # 历史代码,可以忽略
|
||||
├── api_history/ # api历史代码,可以忽略
|
||||
├── chat_history/ # api_chat历史代码,可以忽略
|
||||
└── api_old/ # api历史代码,可以忽略
|
||||
| └── weight.json # GPT_SoVITS模型——权重
|
||||
|
|
||||
├── producer_chat/ # 聊天生产者
|
||||
├── producer/ # 算法生产者,分配任务
|
||||
└── README # 说明文档
|
||||
```
|
||||
|
||||
## 主要功能
|
||||
@@ -52,10 +52,10 @@
|
||||
- 多模型支持(通过Ollama)
|
||||
|
||||
## 使用说明
|
||||
### API 部分 http://dev.obscura.work/v1
|
||||
### API 部分 http://dev2.obscura.work/v1
|
||||
1. producer 目录 # 生产者,分配任务
|
||||
2. 服务器:222.186.10.253:8005
|
||||
3. kafka 配置:222.186.10.253:9092
|
||||
2. 服务器:222.186.20.67:8005
|
||||
3. kafka 配置:222.186.20.67:9092
|
||||
topic分配:
|
||||
- yolo: "yolo"
|
||||
- pose: "pose"
|
||||
@@ -68,7 +68,7 @@
|
||||
- mediapipe: "mediapipe"
|
||||
- compare: "compare"
|
||||
|
||||
4. redis 配置:150.158.144.159:13003
|
||||
4. redis 配置:222.186.20.67:6379
|
||||
db分配:
|
||||
- 4: "yolo"
|
||||
- 5: "pose"
|
||||
@@ -82,26 +82,26 @@
|
||||
- 30: "compare"
|
||||
|
||||
5. 模型配置:
|
||||
- YOLO = "/obscura/models/yolov8x.pt"
|
||||
- POSE = "/obscura/models/yolov8x-pose.pt"
|
||||
- QWEN = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
|
||||
- YOLO = "/obscura/models/yolo11n.pt"
|
||||
- POSE = "/obscura/models/yolo11n-pose.pt"
|
||||
- QWEN = "/obscura/models/qwen/Qwen2.5-VL-7B-Instruct"
|
||||
- FALL = "/obscura/models/yolov8n-fall.pt"
|
||||
- FACE = "/obscura/models/yolov8n-face.pt"
|
||||
- FACE = "/obscura/models/yolo11n-face.pt"
|
||||
- MEDIAPIPE = "/obscura/models/face_landmarker.task"
|
||||
- CPM(ollama) = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate"
|
||||
- CPM(ollama) = "https://222.186.20.67:11435/api/generate"
|
||||
6. 上传文件及结果保存目录:
|
||||
- UPLOAD_DIR = "/obscura/task/upload"
|
||||
- RESULT_DIR = "/obscura/task/result"
|
||||
|
||||
### API_Chat 部分 http://dev.obscura.work/v1_chat
|
||||
### API_Chat 部分 http://dev2.obscura.work/v1_chat
|
||||
1. producer_chat 目录 # 聊天生产者
|
||||
2. 服务器:222.186.10.253:8008
|
||||
3. kafka 配置:222.186.10.253:9092
|
||||
2. 服务器:222.186.20.67:8008
|
||||
3. kafka 配置:222.186.20.67:9092
|
||||
topic分配:
|
||||
- asr: "asr"
|
||||
- chat: "chat"
|
||||
- tts: "tts"
|
||||
4. redis 配置:150.158.144.159:13003
|
||||
4. redis 配置:222.186.20.67:6379
|
||||
db分配:
|
||||
- 2: "api key"
|
||||
- 3: "api使用情况"
|
||||
|
||||
Binary file not shown.
+66
-51
@@ -1,4 +1,11 @@
|
||||
import os
|
||||
# GPU 环境变量设置
|
||||
|
||||
# 首先检查可用的 GPU
|
||||
import torch
|
||||
|
||||
|
||||
device = torch.device('cuda:1')
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
@@ -8,30 +15,17 @@ import json
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
import torch
|
||||
from config import *
|
||||
import insightface
|
||||
from insightface.app import FaceAnalysis
|
||||
from insightface.utils import face_align
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, DEEPFACE_MODEL_PATH
|
||||
)
|
||||
from deepface import DeepFace
|
||||
|
||||
|
||||
torch.cuda.set_device(1)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# 配置
|
||||
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["compare"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"compare_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["compare"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 初始化 Kafka
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
@@ -47,7 +41,7 @@ redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
db=WORKER_CONFIGS["compare"]["redis_db"]
|
||||
)
|
||||
|
||||
main_redis_client = Redis(
|
||||
@@ -60,18 +54,44 @@ main_redis_client = Redis(
|
||||
|
||||
class FaceComparator:
|
||||
def __init__(self):
|
||||
print("初始化 InsightFace...")
|
||||
self.app = FaceAnalysis(name='buffalo_l', allowed_modules=['detection', 'recognition', 'genderage'])
|
||||
self.app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
print(f"InsightFace 初始化完成,使用设备: {device}")
|
||||
|
||||
print("初始化 DeepFace...")
|
||||
self.model_path = DEEPFACE_MODEL_PATH
|
||||
os.environ["DEEPFACE_HOME"] = self.model_path
|
||||
|
||||
def detect(self, frame):
|
||||
"""检测人脸并返回所有特征"""
|
||||
try:
|
||||
faces = self.app.get(frame)
|
||||
return faces
|
||||
# 直接使用 represent 获取特征向量
|
||||
embeddings = DeepFace.represent(
|
||||
frame,
|
||||
model_name="Facenet512",
|
||||
detector_backend='retinaface',
|
||||
align=True, # 添加对齐选项
|
||||
enforce_detection=False
|
||||
)
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
embeddings = [embeddings]
|
||||
|
||||
# 获取其他属性
|
||||
results = DeepFace.analyze(
|
||||
frame,
|
||||
actions=['age', 'gender', 'race', 'emotion'],
|
||||
detector_backend='retinaface',
|
||||
enforce_detection=False
|
||||
)
|
||||
|
||||
if not isinstance(results, list):
|
||||
results = [results]
|
||||
|
||||
# 合并结果
|
||||
for i in range(len(results)):
|
||||
if i < len(embeddings):
|
||||
results[i]['embedding'] = embeddings[i]['embedding'] # 注意这里的变化
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"人脸检测出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def format_results(self, faces):
|
||||
@@ -79,24 +99,24 @@ class FaceComparator:
|
||||
print("开始格式化检测结果...")
|
||||
try:
|
||||
formatted_results = []
|
||||
for i, face in enumerate(faces):
|
||||
# 设置默认值,避免None导致的错误
|
||||
for face in faces:
|
||||
region = face.get('region', {})
|
||||
result = {
|
||||
'bbox': face.bbox.tolist(),
|
||||
'kps': face.kps.tolist(),
|
||||
'gender': int(face.gender) if hasattr(face, 'gender') and face.gender is not None else -1,
|
||||
'age': float(face.age) if hasattr(face, 'age') and face.age is not None else 0.0,
|
||||
'det_score': float(face.det_score),
|
||||
'embedding': face.embedding.tolist()
|
||||
'bbox': [
|
||||
region.get('x', 0),
|
||||
region.get('y', 0),
|
||||
region.get('w', 0),
|
||||
region.get('h', 0)
|
||||
],
|
||||
'gender': 1 if face.get('gender', '') == 'Man' else 0,
|
||||
'age': float(face.get('age', 0)),
|
||||
'det_score': 1.0,
|
||||
'embedding': face.get('embedding', []) # embedding 已经是列表格式
|
||||
}
|
||||
formatted_results.append(result)
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
print(f"格式化结果时出错: {str(e)}")
|
||||
print(f"错误类型: {type(e)}")
|
||||
import traceback
|
||||
print(f"错误堆栈: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
def draw_results(self, frame, results):
|
||||
@@ -106,32 +126,27 @@ class FaceComparator:
|
||||
annotated_frame = frame.copy()
|
||||
|
||||
if not results:
|
||||
print("没有检测结果,返回原始图像")
|
||||
return annotated_frame
|
||||
|
||||
for i, r in enumerate(results):
|
||||
for r in results:
|
||||
bbox = r['bbox']
|
||||
kps = r['kps']
|
||||
x, y, w, h = list(map(int, bbox))
|
||||
|
||||
cv2.rectangle(annotated_frame,
|
||||
(int(bbox[0]), int(bbox[1])),
|
||||
(int(bbox[2]), int(bbox[3])),
|
||||
(x, y),
|
||||
(x + w, y + h),
|
||||
(0, 255, 0), 2)
|
||||
|
||||
for x, y in kps:
|
||||
cv2.circle(annotated_frame,
|
||||
(int(x), int(y)),
|
||||
3, (255, 255, 0), -1)
|
||||
|
||||
gender_text = 'Male' if r['gender'] == 1 else 'Female' if r['gender'] == 0 else 'Unknown'
|
||||
gender_text = 'Male' if r['gender'] == 1 else 'Female'
|
||||
label = f"Age: {int(r['age'])} Gender: {gender_text}"
|
||||
cv2.putText(annotated_frame,
|
||||
label,
|
||||
(int(bbox[0]), int(bbox[1]-10)),
|
||||
(x, y-10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5,
|
||||
(0, 255, 0),
|
||||
2)
|
||||
|
||||
return annotated_frame
|
||||
|
||||
except Exception as e:
|
||||
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
# config.py
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = "222.186.20.67:9092"
|
||||
KAFKA_GROUP_ID_PREFIX = "group"
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = "222.186.20.67"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
MAIN_REDIS_DB = 30
|
||||
REDIS_API_DB = 31
|
||||
REDIS_API_USAGE_DB = 32
|
||||
# 目录配置
|
||||
UPLOAD_DIR = "/obscura/task/upload"
|
||||
RESULT_DIR = "/obscura/task/result"
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 模型配置
|
||||
YOLO_MODEL_PATH = "/obscura/models/yolo11n.pt"
|
||||
POSE_MODEL_PATH = "/obscura/models/yolo11n-pose.pt"
|
||||
QWEN_MODEL_PATH = "/obscura/models/QWEN/Qwen2___5-VL-7B-Instruct"
|
||||
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
|
||||
FACE_MODEL_PATH = "/obscura/models/yolo11n-face.pt"
|
||||
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
|
||||
DEEPFACE_MODEL_PATH = "/obscura/models"
|
||||
# Ollama配置
|
||||
OLLAMA_URLS = [
|
||||
"http://222.186.20.67:11434/api/generate",
|
||||
"http://222.186.20.67:11435/api/generate",
|
||||
"http://222.186.20.67:11436/api/generate",
|
||||
"http://222.186.20.67:11437/api/generate",
|
||||
"http://222.186.20.67:11438/api/generate",
|
||||
"http://222.186.20.67:11439/api/generate",
|
||||
"http://222.186.20.67:11440/api/generate",
|
||||
"http://222.186.20.67:11441/api/generate",
|
||||
# 在这里添加更多的API地址
|
||||
]
|
||||
# 随机选择一个API
|
||||
OLLAMA_URL = random.choice(OLLAMA_URLS)
|
||||
|
||||
# 各个worker的配置
|
||||
WORKER_CONFIGS = {
|
||||
"yolo": {
|
||||
"kafka_topic": "yolo",
|
||||
"redis_db": 33,
|
||||
},
|
||||
"pose": {
|
||||
"kafka_topic": "pose",
|
||||
"redis_db": 34,
|
||||
},
|
||||
"qwenvl": {
|
||||
"kafka_topic": "qwenvl",
|
||||
"redis_db": 35,
|
||||
},
|
||||
"qwenvl_analyze": {
|
||||
"kafka_topic": "qwenvl_analyze",
|
||||
"redis_db": 36,
|
||||
},
|
||||
"cpm": {
|
||||
"kafka_topic": "cpm",
|
||||
"redis_db": 37,
|
||||
},
|
||||
"cpm_analyze": {
|
||||
"kafka_topic": "cpm_analyze",
|
||||
"redis_db": 38,
|
||||
},
|
||||
"fall": {
|
||||
"kafka_topic": "fall",
|
||||
"redis_db": 39,
|
||||
},
|
||||
"face": {
|
||||
"kafka_topic": "face",
|
||||
"redis_db": 40,
|
||||
},
|
||||
"mediapipe": {
|
||||
"kafka_topic": "mediapipe",
|
||||
"redis_db": 41,
|
||||
},
|
||||
"compare": {
|
||||
"kafka_topic": "compare",
|
||||
"redis_db": 42,
|
||||
}
|
||||
}
|
||||
|
||||
# GPU设置
|
||||
CUDA_DEVICE_0 = "cuda:0"
|
||||
CUDA_DEVICE_1 = "cuda:1"
|
||||
CUDA_DEVICE_2 = "cuda:2"
|
||||
CUDA_DEVICE_3 = "cuda:3"
|
||||
CUDA_DEVICE_4 = "cuda:4"
|
||||
CUDA_DEVICE_5 = "cuda:5"
|
||||
CUDA_DEVICE_6 = "cuda:6"
|
||||
CUDA_DEVICE_7 = "cuda:7"
|
||||
CUDA_DEVICE_8 = "cuda:8"
|
||||
+5
-10
@@ -13,22 +13,17 @@ import requests
|
||||
import base64
|
||||
import traceback
|
||||
import json
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, OLLAMA_URL
|
||||
)
|
||||
|
||||
# 配置
|
||||
OLLAMA_URL = OLLAMA_URL
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["cpm_analyze"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"cpm_analyze_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["cpm_analyze"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
+6
-12
@@ -13,22 +13,16 @@ import requests
|
||||
import base64
|
||||
import traceback
|
||||
import json
|
||||
from config import *
|
||||
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, OLLAMA_URL
|
||||
)
|
||||
# 配置
|
||||
OLLAMA_URL = OLLAMA_URL
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["cpm"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"cpm_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["cpm"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
@@ -222,7 +216,7 @@ class MediaAnalysisSystem:
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "��外", "会议室"]
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
|
||||
+69
-57
@@ -9,23 +9,19 @@ from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
import torch
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, FACE_MODEL_PATH
|
||||
)
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = FACE_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["face"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"face_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["face"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 初始化 Kafka
|
||||
consumer = KafkaConsumer(
|
||||
@@ -64,49 +60,55 @@ class faceDetector:
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
# 获取原始图像和模型输出的尺寸比例
|
||||
orig_h, orig_w = original_shape[:2]
|
||||
model_h, model_w = r.orig_shape
|
||||
scale_x, scale_y = orig_w / model_w, orig_h / model_h
|
||||
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
try:
|
||||
box = boxes[i]
|
||||
bbox = box.xyxy[0].cpu().numpy()
|
||||
bbox_scaled = [
|
||||
bbox[0] * scale_x, bbox[1] * scale_y,
|
||||
bbox[2] * scale_x, bbox[3] * scale_y
|
||||
]
|
||||
|
||||
# 调整边界框坐标以适应原始图像大小
|
||||
orig_h, orig_w = original_shape[:2]
|
||||
model_h, model_w = r.orig_shape
|
||||
scale_x, scale_y = orig_w / model_w, orig_h / model_h
|
||||
result = {
|
||||
"bbox": bbox_scaled,
|
||||
"confidence": box.conf.item(),
|
||||
}
|
||||
|
||||
bbox = box.xyxy[0].cpu().numpy()
|
||||
bbox_scaled = [
|
||||
bbox[0] * scale_x, bbox[1] * scale_y,
|
||||
bbox[2] * scale_x, bbox[3] * scale_y
|
||||
]
|
||||
# 如果有关键点信息,则添加到结果中
|
||||
if hasattr(r, 'keypoints') and r.keypoints is not None:
|
||||
kpts = r.keypoints[i]
|
||||
if kpts is not None and kpts.xy is not None and len(kpts.xy) > 0:
|
||||
kpts_scaled = kpts.xy[0].cpu().numpy() * np.array([scale_x, scale_y])
|
||||
result["keypoints"] = kpts_scaled.tolist()
|
||||
|
||||
# 调整关键点坐标以适应原始图像大小
|
||||
kpts_scaled = kpts.xy[0].cpu().numpy() * np.array([scale_x, scale_y])
|
||||
formatted_results.append(result)
|
||||
except Exception as e:
|
||||
print(f"处理单个检测结果时出错: {str(e)}")
|
||||
continue
|
||||
|
||||
formatted_results.append({
|
||||
"bbox": bbox_scaled,
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts_scaled.tolist()
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results):
|
||||
annotated_frame = frame.copy()
|
||||
for r in results:
|
||||
bbox = r["bbox"]
|
||||
keypoints = r["keypoints"]
|
||||
|
||||
# 绘制边界框
|
||||
cv2.rectangle(annotated_frame,
|
||||
(int(bbox[0]), int(bbox[1])),
|
||||
(int(bbox[2]), int(bbox[3])),
|
||||
(0, 255, 0), 2)
|
||||
|
||||
# 绘制关键点
|
||||
for kp in keypoints:
|
||||
cv2.circle(annotated_frame,
|
||||
(int(kp[0]), int(kp[1])),
|
||||
5, (255, 0, 0), -1)
|
||||
# 如果有关键点信息,则绘制关键点
|
||||
if "keypoints" in r:
|
||||
for kp in r["keypoints"]:
|
||||
cv2.circle(annotated_frame,
|
||||
(int(kp[0]), int(kp[1])),
|
||||
5, (255, 0, 0), -1)
|
||||
return annotated_frame
|
||||
|
||||
|
||||
@@ -217,15 +219,19 @@ def process_task():
|
||||
result_path = os.path.join(RESULT_DIR, result_filename)
|
||||
cv2.imwrite(result_path, annotated_img)
|
||||
|
||||
redis_client.hmset(f"face_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "face",
|
||||
"result_key": f"face_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"face_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "face",
|
||||
"result_key": f"face_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"图像 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"图像 {filename} 处理失败")
|
||||
@@ -235,25 +241,31 @@ def process_task():
|
||||
json_results = process_video(file_path)
|
||||
if json_results:
|
||||
result_filename = f"face_{filename}"
|
||||
redis_client.hmset(f"face_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "face",
|
||||
"result_key": f"face_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"face_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "face",
|
||||
"result_key": f"face_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"视频 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"视频 {filename} 处理失败")
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "failed")
|
||||
except Exception as e:
|
||||
print(f"处理任务 {task_id} 时出错: {str(e)}")
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
||||
def listen_redis_changes():
|
||||
|
||||
+5
-9
@@ -8,21 +8,17 @@ import json
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, FALL_MODEL_PATH
|
||||
)
|
||||
# 配置
|
||||
MODEL_PATH = FALL_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["fall"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"fall_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["fall"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
|
||||
+5
-9
@@ -11,21 +11,17 @@ import torch
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, MEDIAPIPE_MODEL_PATH
|
||||
)
|
||||
# 配置
|
||||
MODEL_PATH = MEDIAPIPE_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["mediapipe"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"mediapipe_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["mediapipe"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
|
||||
# Ensure directories exist
|
||||
|
||||
+32
-28
@@ -9,22 +9,18 @@ from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
import torch
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, POSE_MODEL_PATH,CUDA_DEVICE_1
|
||||
)
|
||||
torch.cuda.set_device(CUDA_DEVICE_1)
|
||||
# 配置
|
||||
MODEL_PATH = POSE_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["pose"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"pose_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["pose"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
@@ -219,15 +215,19 @@ def process_task():
|
||||
result_path = os.path.join(RESULT_DIR, result_filename)
|
||||
cv2.imwrite(result_path, annotated_img)
|
||||
|
||||
redis_client.hmset(f"pose_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "pose",
|
||||
"result_key": f"pose_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"pose_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "pose",
|
||||
"result_key": f"pose_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"图像 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"图像 {filename} 处理失败")
|
||||
@@ -237,22 +237,26 @@ def process_task():
|
||||
json_results = process_video(file_path)
|
||||
if json_results:
|
||||
result_filename = f"pose_{filename}"
|
||||
redis_client.hmset(f"pose_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "pose",
|
||||
"result_key": f"pose_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"pose_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "pose",
|
||||
"result_key": f"pose_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"视频 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"视频 {filename} 处理失败")
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "failed")
|
||||
except Exception as e:
|
||||
print(f"处理任务 {task_id} 时出错: {str(e)}")
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
main_redis_client.hset(f"task:{task_id}", {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
+11
-14
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from kafka import KafkaConsumer
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
@@ -12,22 +11,20 @@ from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import threading
|
||||
from config import *
|
||||
import torch
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, QWEN_MODEL_PATH
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = QWEN_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["qwenvl_analyze"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"qwenvl_analyze_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["qwenvl_analyze"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
@@ -59,13 +56,13 @@ main_redis_client = Redis(
|
||||
|
||||
|
||||
# 初始化模型
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype="auto", device_map="cuda:0"
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda:2"
|
||||
)
|
||||
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels, use_fast=True)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, processor):
|
||||
@@ -136,7 +133,7 @@ class MediaAnalysisSystem:
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to('cuda:0')
|
||||
inputs = inputs.to('cuda:2')
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
|
||||
+11
-14
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from kafka import KafkaConsumer
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
@@ -12,22 +11,20 @@ from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import threading
|
||||
from config import *
|
||||
import torch
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, QWEN_MODEL_PATH
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = QWEN_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["qwenvl"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"qwenvl_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["qwenvl"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
@@ -59,13 +56,13 @@ main_redis_client = Redis(
|
||||
|
||||
|
||||
# 初始化模型
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype="auto", device_map="cuda:1"
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda:3"
|
||||
)
|
||||
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH,min_pixels=min_pixels, max_pixels=max_pixels, use_fast=True)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, processor):
|
||||
@@ -137,7 +134,7 @@ class MediaAnalysisSystem:
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to('cuda:1')
|
||||
inputs = inputs.to('cuda:3')
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=2048)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
python-multipart
|
||||
kafka-python==2.0.2
|
||||
six>=1.10.0 # 添加six包来解决kafka依赖问题
|
||||
redis
|
||||
python-dotenv
|
||||
requests
|
||||
pydantic[email]
|
||||
pydub
|
||||
httpx
|
||||
sqlalchemy
|
||||
passlib[bcrypt]
|
||||
pymysql
|
||||
python-jose[cryptography]
|
||||
Pillow # 添加用于图片处理
|
||||
Executable
+36
@@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 设置工作目录
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
# 创建日志目录
|
||||
mkdir -p logs
|
||||
|
||||
# 定义要启动的服务
|
||||
services=(
|
||||
"yolo.py"
|
||||
"pose.py"
|
||||
"qwenvl_scene.py"
|
||||
"qwenvl_analyze.py"
|
||||
"cpm_scene.py"
|
||||
"cpm_analyze.py"
|
||||
"fall.py"
|
||||
"face.py"
|
||||
"media.py"
|
||||
"compare.py"
|
||||
)
|
||||
|
||||
# 启动所有服务
|
||||
for service in "${services[@]}"; do
|
||||
echo "启动 $service..."
|
||||
# 使用screen创建新的会话并运行Python服务
|
||||
screen_name="${service%.py}"
|
||||
screen -dmS "$screen_name" bash -c "python3 $service > logs/${screen_name}.log 2>&1"
|
||||
# 等待几秒钟,确保服务正常启动
|
||||
sleep 2
|
||||
done
|
||||
|
||||
echo "所有服务已启动,请检查logs目录下的日志文件"
|
||||
echo "使用 'screen -ls' 查看所有screen会话"
|
||||
echo "使用 'screen -r [会话名]' 连接到特定会话"
|
||||
echo "使用 'cat logs/*.log' 查看日志"
|
||||
Executable
+19
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 获取所有screen会话并关闭它们
|
||||
screen -ls | grep Detached | cut -d. -f1 | awk '{print $1}' | while read pid
|
||||
do
|
||||
echo "正在终止screen会话: $pid"
|
||||
screen -S $pid -X quit
|
||||
done
|
||||
|
||||
echo "所有screen会话已终止"
|
||||
|
||||
# 为了以防万一,也可以直接杀死相关进程
|
||||
for prog in compare media face fall cpm_analyze cpm_scene pose yolo
|
||||
do
|
||||
echo "正在检查并终止 $prog 相关进程"
|
||||
pkill -f $prog
|
||||
done
|
||||
|
||||
echo "清理完成"
|
||||
+37
-31
@@ -8,22 +8,18 @@ import json
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
from config import *
|
||||
from config import (
|
||||
KAFKA_BROKER, KAFKA_GROUP_ID_PREFIX, WORKER_CONFIGS,
|
||||
REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, MAIN_REDIS_DB,
|
||||
UPLOAD_DIR, RESULT_DIR, YOLO_MODEL_PATH
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = YOLO_MODEL_PATH
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
KAFKA_TOPIC = WORKER_CONFIGS["yolo"]["kafka_topic"]
|
||||
KAFKA_GROUP_ID = f"yolo_{KAFKA_GROUP_ID_PREFIX}"
|
||||
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = WORKER_CONFIGS["yolo"]["redis_db"] # Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = MAIN_REDIS_DB # 主Redis DB
|
||||
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
|
||||
|
||||
# 确保目录存在
|
||||
@@ -208,15 +204,19 @@ def process_task():
|
||||
result_path = os.path.join(RESULT_DIR, result_filename)
|
||||
cv2.imwrite(result_path, annotated_img)
|
||||
|
||||
redis_client.hmset(f"yolo_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "yolo",
|
||||
"result_key": f"yolo_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"yolo_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "yolo",
|
||||
"result_key": f"yolo_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"图像 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"图像 {filename} 处理失败")
|
||||
@@ -226,25 +226,31 @@ def process_task():
|
||||
json_results = process_video(file_path)
|
||||
if json_results:
|
||||
result_filename = f"yolo_{filename}"
|
||||
redis_client.hmset(f"yolo_result:{task_id}", {
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
})
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "completed",
|
||||
"result_type": "yolo",
|
||||
"result_key": f"yolo_result:{task_id}"
|
||||
})
|
||||
redis_client.hset(f"yolo_result:{task_id}",
|
||||
mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": result_filename
|
||||
}
|
||||
)
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "completed",
|
||||
"result_type": "yolo",
|
||||
"result_key": f"yolo_result:{task_id}"
|
||||
}
|
||||
)
|
||||
print(f"视频 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"视频 {filename} 处理失败")
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "failed")
|
||||
except Exception as e:
|
||||
print(f"处理任务 {task_id} 时出错: {str(e)}")
|
||||
main_redis_client.hmset(f"task:{task_id}", {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
main_redis_client.hset(f"task:{task_id}",
|
||||
mapping={
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
||||
|
||||
|
||||
+30
-29
@@ -1,28 +1,27 @@
|
||||
# Kafka 配置
|
||||
KAFKA_BROKER=222.186.10.253:9092
|
||||
KAFKA_BROKER=222.186.20.67:9092
|
||||
KAFKA_ASR_TOPIC=asr
|
||||
KAFKA_CHAT_TOPIC=chat
|
||||
KAFKA_TTS_TOPIC=tts
|
||||
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST=150.158.144.159
|
||||
REDIS_PORT=13003
|
||||
REDIS_ASR_DB=12
|
||||
REDIS_CHAT_DB=13
|
||||
REDIS_TTS_DB=14
|
||||
REDIS_HOST=222.186.20.67
|
||||
REDIS_PORT=6379
|
||||
REDIS_ASR_DB=43
|
||||
REDIS_CHAT_DB=44
|
||||
REDIS_TTS_DB=45
|
||||
REDIS_PASSWORD=Obscura@2024
|
||||
REDIS_API_DB=2
|
||||
REDIS_API_USAGE_DB=3
|
||||
REDIS_TASK_DB=11
|
||||
REDIS_SESSION_DB=63
|
||||
REDIS_API_DB=31
|
||||
REDIS_API_USAGE_DB=32
|
||||
REDIS_TASK_DB=46
|
||||
REDIS_SESSION_DB=47
|
||||
|
||||
REDIS_SESSION_DB_ZH=48
|
||||
REDIS_SESSION_DB_EN=49
|
||||
REDIS_SESSION_DB_KO=50
|
||||
|
||||
REDIS_SESSION_DB_ZH=63
|
||||
REDIS_SESSION_DB_EN=62
|
||||
REDIS_SESSION_DB_KO=61
|
||||
|
||||
# CORS 配置
|
||||
# ALLOWED_ORIGINS=https://beta.obscura.work
|
||||
|
||||
|
||||
# GPT-SoVITS 配置
|
||||
@@ -76,19 +75,21 @@ XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
|
||||
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
|
||||
|
||||
|
||||
REDIS_GIRL_DB = 15
|
||||
REDIS_WOMAN_DB = 16
|
||||
REDIS_MAN_DB = 17
|
||||
REDIS_LEIJUN_DB = 18
|
||||
REDIS_DUFU_DB = 19
|
||||
REDIS_HEJIONG_DB = 20
|
||||
REDIS_MAHUATENG_DB = 21
|
||||
REDIS_LIDAN_DB = 22
|
||||
REDIS_DABING_DB = 23
|
||||
REDIS_LUOXIANG_DB = 24
|
||||
REDIS_XUZHIYUAN_DB = 25
|
||||
REDIS_YUHUA_DB = 26
|
||||
REDIS_LIUZHENYUN_DB = 27
|
||||
|
||||
REDIS_GIRL_DB = 51
|
||||
REDIS_WOMAN_DB = 52
|
||||
REDIS_MAN_DB = 53
|
||||
REDIS_LEIJUN_DB = 54
|
||||
REDIS_DUFU_DB = 55
|
||||
REDIS_HEJIONG_DB = 56
|
||||
REDIS_MAHUATENG_DB = 57
|
||||
REDIS_LIDAN_DB = 58
|
||||
REDIS_DABING_DB = 59
|
||||
REDIS_LUOXIANG_DB = 60
|
||||
REDIS_XUZHIYUAN_DB = 61
|
||||
REDIS_YUHUA_DB = 62
|
||||
REDIS_LIUZHENYUN_DB = 63
|
||||
|
||||
|
||||
# Ollama API配置 - 多个地址用逗号分隔
|
||||
OLLAMA_URLS=http://222.186.20.67:11435,http://222.186.20.67:11436,http://222.186.20.67:11437,http://222.186.20.67:11438,http://222.186.20.67:11439,http://222.186.20.67:11440,http://222.186.20.67:11441
|
||||
OLLAMA_TIMEOUT=10 # API请求超时时间(秒)
|
||||
Regular → Executable
+4
-1
@@ -6,6 +6,9 @@ from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import asyncio
|
||||
|
||||
# 在导入其他库之前设置
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
@@ -16,7 +19,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
model = whisper.load_model("small")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import whisper
|
||||
import os
|
||||
import json
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_asr_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_ASR_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
def process_audio(file_path: str, client_id: str, cache_key: str):
|
||||
try:
|
||||
# 设置任务状态为 "processing"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "processing")
|
||||
|
||||
result = model.transcribe(file_path)
|
||||
transcription = result['text']
|
||||
|
||||
print(f"处理了文件: {file_path}")
|
||||
print(f"转录结果: {transcription}")
|
||||
|
||||
# 将结果存入Redis缓存
|
||||
redis_asr_client.setex(cache_key, 3600, transcription) # 缓存1小时
|
||||
|
||||
# 发布结果到Redis频道
|
||||
result_data = {
|
||||
'client_id': client_id,
|
||||
'transcription': transcription
|
||||
}
|
||||
redis_asr_client.publish('asr_results', json.dumps(result_data))
|
||||
|
||||
# 设置任务状态为 "completed"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "completed")
|
||||
|
||||
# 清理临时文件
|
||||
os.remove(file_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理音频文件时发生错误: {str(e)}")
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "error")
|
||||
|
||||
def kafka_consumer():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
group_id='asr_group',
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True
|
||||
)
|
||||
|
||||
print(f"ASR消费者已启动")
|
||||
|
||||
for message in consumer:
|
||||
try:
|
||||
task = message.value
|
||||
file_path = task.get('file_path')
|
||||
task_id = task.get('task_id')
|
||||
status = task.get('status')
|
||||
|
||||
if not file_path or not task_id or status != 'queued':
|
||||
print(f"收到无效任务: {task}")
|
||||
continue
|
||||
|
||||
cache_key = f"asr:{task_id}"
|
||||
client_id = task_id # 使用task_id作为client_id
|
||||
|
||||
print(f"开始处理任务: {cache_key}")
|
||||
process_audio(file_path, client_id, cache_key)
|
||||
print(f"完成处理任务: {cache_key}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理消息时发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("启动Kafka消费者处理ASR请求...")
|
||||
kafka_consumer()
|
||||
@@ -1,117 +0,0 @@
|
||||
from kafka import KafkaConsumer
|
||||
import json
|
||||
import asyncio
|
||||
import redis
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 3 # 消费者数量
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_CHAT_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
KAFKA_CHAT_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='latest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
model = chat_request.get('model', 'qwen2.5:3b')
|
||||
|
||||
# 设置任务状态为 "processing"
|
||||
redis_task_client.set(f"task_status:{session_id}", "processing")
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
|
||||
for past_query, past_response in history:
|
||||
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
|
||||
full_prompt += f"用户: {query}\n助手:"
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": full_prompt,
|
||||
"stream": True,
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
json_data = json.loads(line)
|
||||
if 'response' in json_data:
|
||||
text_output += json_data['response']
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
# 设置任务状态为 "completed"
|
||||
redis_task_client.set(f"task_status:{session_id}", "completed")
|
||||
|
||||
print(f"处理完成 session {session_id}: {text_output}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理 session {chat_request['session_id']} 时出错: {str(e)}")
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"task_status:{chat_request['session_id']}", "error")
|
||||
|
||||
def kafka_consumer_thread(consumer_id):
|
||||
consumer = create_kafka_consumer()
|
||||
print(f"消费者 {consumer_id} 已启动")
|
||||
for message in consumer:
|
||||
chat_request = message.value
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
def main():
|
||||
print("启动Kafka消费者处理聊天请求...")
|
||||
with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
|
||||
for i in range(KAFKA_CONSUMER_NUM):
|
||||
executor.submit(kafka_consumer_thread, i)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,182 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load MiniCPM3-4B model
|
||||
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'mini3_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "minicpm3-4b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/mini3", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
session_id = request.session_id
|
||||
query = request.query
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
|
||||
for past_query, past_response in history:
|
||||
messages.append({"role": "user", "content": past_query})
|
||||
messages.append({"role": "assistant", "content": past_response})
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
try:
|
||||
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
|
||||
|
||||
# 创建注意力掩码
|
||||
attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()
|
||||
|
||||
# 将输入移动到正确的设备(CPU或GPU)
|
||||
model_inputs = model_inputs.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
model_outputs = model.generate(
|
||||
model_inputs,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=1024,
|
||||
top_p=0.7,
|
||||
temperature=0.7,
|
||||
pad_token_id=tokenizer.eos_token_id, # 将pad_token_id设置为eos_token_id
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
output_token_ids = model_outputs[0][len(model_inputs[0]):]
|
||||
text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
return ChatResponse(response=text_output, history=history)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
|
||||
async def start_chat():
|
||||
session_id = str(uuid.uuid4())
|
||||
return {"session_id": session_id}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -1,63 +0,0 @@
|
||||
import os
|
||||
from moviepy.editor import VideoFileClip
|
||||
|
||||
def mp4_to_wav(input_file, output_file):
|
||||
"""
|
||||
将MP4文件转换为WAV格式
|
||||
|
||||
:param input_file: 输入的MP4文件路径
|
||||
:param output_file: 输出的WAV文件路径
|
||||
"""
|
||||
try:
|
||||
# 加载视频文件
|
||||
video = VideoFileClip(input_file)
|
||||
|
||||
# 提取音频
|
||||
audio = video.audio
|
||||
|
||||
# 将音频写入WAV文件
|
||||
audio.write_audiofile(output_file)
|
||||
|
||||
# 关闭视频和音频对象
|
||||
audio.close()
|
||||
video.close()
|
||||
|
||||
print(f"转换成功: {input_file} -> {output_file}")
|
||||
except Exception as e:
|
||||
print(f"转换失败: {input_file} - {str(e)}")
|
||||
|
||||
def process_directory(directory):
|
||||
"""
|
||||
处理目录中的所有MP4文件
|
||||
|
||||
:param directory: 包含MP4文件的目录路径
|
||||
"""
|
||||
for filename in os.listdir(directory):
|
||||
if filename.lower().endswith('.mp4'):
|
||||
input_file = os.path.join(directory, filename)
|
||||
output_file = os.path.splitext(input_file)[0] + ".wav"
|
||||
mp4_to_wav(input_file, output_file)
|
||||
|
||||
def main():
|
||||
# 获取输入路径
|
||||
input_path = input("请输入MP4文件或包含MP4文件的目录路径: ").strip()
|
||||
|
||||
# 检查输入路径是否存在
|
||||
if not os.path.exists(input_path):
|
||||
print("错误: 输入路径不存在")
|
||||
return
|
||||
|
||||
# 判断输入路径是文件还是目录
|
||||
if os.path.isfile(input_path):
|
||||
if not input_path.lower().endswith('.mp4'):
|
||||
print("错误: 输入文件不是MP4格式")
|
||||
return
|
||||
output_file = os.path.splitext(input_path)[0] + ".wav"
|
||||
mp4_to_wav(input_path, output_file)
|
||||
elif os.path.isdir(input_path):
|
||||
process_directory(input_path)
|
||||
else:
|
||||
print("错误: 输入路径既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,170 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
session_id = request.session_id
|
||||
query = request.query
|
||||
model = request.model
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
|
||||
for past_query, past_response in history:
|
||||
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
|
||||
full_prompt += f"用户: {query}"
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": full_prompt,
|
||||
"stream": True,
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
json_data = json.loads(line)
|
||||
if 'response' in json_data:
|
||||
text_output += json_data['response']
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
return ChatResponse(response=text_output, history=history)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
|
||||
async def start_chat():
|
||||
session_id = str(uuid.uuid4())
|
||||
return {"session_id": session_id}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6001)
|
||||
@@ -1,136 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import httpx
|
||||
import json
|
||||
import redis
|
||||
from typing import List, Dict, Optional
|
||||
import logging
|
||||
import ollama
|
||||
import uuid
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Redis连接
|
||||
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
model: Optional[str] = "qwen2.5:3b"
|
||||
prompt: str
|
||||
|
||||
class RawGenerateRequest(BaseModel):
|
||||
model: Optional[str] = "qwen2.5:3b"
|
||||
prompt: str
|
||||
system_prompt: Optional[str] = None
|
||||
stream: Optional[bool] = False
|
||||
raw: Optional[bool] = False
|
||||
format: Optional[str] = None
|
||||
options: Optional[Dict] = None
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
response: dict
|
||||
request_id: str
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
logger.info(f"收到请求: {request}")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
|
||||
full_response = response['message']['content']
|
||||
|
||||
request_data = {
|
||||
"model": request.model,
|
||||
"prompt": request.prompt,
|
||||
"response": full_response
|
||||
}
|
||||
|
||||
redis_client.set(f"request:{request_id}", json.dumps(request_data))
|
||||
|
||||
response_data = {
|
||||
"response": full_response,
|
||||
"model": request.model
|
||||
}
|
||||
|
||||
return GenerateResponse(response=response_data, request_id=request_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate_without_history(request: RawGenerateRequest):
|
||||
"""
|
||||
处理无历史记录的生成请求。
|
||||
|
||||
参数:
|
||||
- request: RawGenerateRequest对象,包含生成请求的所有参数。
|
||||
|
||||
返回:
|
||||
- 包含生成结果的字典。
|
||||
"""
|
||||
try:
|
||||
response = ollama.generate(
|
||||
model=request.model,
|
||||
prompt=request.prompt,
|
||||
system=request.system_prompt,
|
||||
format=request.format,
|
||||
options=request.options,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"model": request.model,
|
||||
"response": response['response'],
|
||||
"done": True,
|
||||
"context": response.get('context'),
|
||||
"total_duration": response.get('total_duration'),
|
||||
"load_duration": response.get('load_duration'),
|
||||
"prompt_eval_count": response.get('prompt_eval_count'),
|
||||
"prompt_eval_duration": response.get('prompt_eval_duration'),
|
||||
"eval_count": response.get('eval_count'),
|
||||
"eval_duration": response.get('eval_duration')
|
||||
}
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
redis_client.set(f"request:{request_id}", json.dumps(response_data))
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生未预期的错误: {e}")
|
||||
logger.exception("详细错误信息:")
|
||||
raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
|
||||
|
||||
@app.get("/request/{request_id}", response_model=Dict)
|
||||
async def get_request(request_id: str):
|
||||
request_data = redis_client.get(f"request:{request_id}")
|
||||
if request_data:
|
||||
return json.loads(request_data)
|
||||
raise HTTPException(status_code=404, detail="请求未找到")
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
return ollama.list()
|
||||
|
||||
@app.get("/models/{model_name}")
|
||||
async def show_model(model_name: str):
|
||||
return ollama.show(model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -1,319 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
import asyncio
|
||||
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
v1_chat_app = FastAPI()
|
||||
app.mount("/v1_chat", v1_chat_app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
|
||||
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
|
||||
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
|
||||
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
|
||||
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
|
||||
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
|
||||
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
# 发送任务到Kafka
|
||||
producer.send(kafka_topic, task_data)
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return {
|
||||
"message": f"{model_name.upper()}请求已排队等待处理",
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
}
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
|
||||
# 添加WebSocket连接管理
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, client_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[client_id] = websocket
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
self.active_connections.pop(client_id, None)
|
||||
|
||||
async def send_message(self, message: str, client_id: str):
|
||||
if client_id in self.active_connections:
|
||||
await self.active_connections[client_id].send_text(message)
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@v1_chat_app.websocket("/ws/{client_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||||
await manager.connect(websocket, client_id)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(client_id)
|
||||
|
||||
# 修改TTS请求处理函数
|
||||
@v1_chat_app.post("/tts")
|
||||
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"text": request.text,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "tts", 100, task_data, KAFKA_TTS_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
# 修改ASR请求处理函数
|
||||
@v1_chat_app.post("/asr")
|
||||
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
UPLOAD_DIR = "/obscura/task/audio_upload"
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
|
||||
|
||||
with open(file_path, "wb") as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
|
||||
task_data = {
|
||||
'file_path': file_path,
|
||||
'task_id': task_id,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "asr", 100, task_data, KAFKA_ASR_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
# 修改聊天请求处理函数
|
||||
@v1_chat_app.post("/chat")
|
||||
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"session_id": request.session_id,
|
||||
"query": request.query,
|
||||
"model": request.model,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "chat", 100, task_data, KAFKA_CHAT_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
@v1_chat_app.get("/chat_result/{task_id}")
|
||||
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis聊天结果数据库获取聊天结果
|
||||
chat_result = redis_chat_client.get(task_id)
|
||||
if chat_result:
|
||||
result = json.loads(chat_result)
|
||||
return {
|
||||
"status": "completed",
|
||||
"history": result # 直接返回整个历史记录
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_result/{task_id}")
|
||||
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis TTS结果数据库获取音频文件路径
|
||||
audio_info = redis_tts_client.get(task_id)
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
return {
|
||||
"status": "completed",
|
||||
"audio_path": audio_path
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/asr_result/{task_id}")
|
||||
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis ASR结果数据库获取转录结果
|
||||
transcription = redis_asr_client.get(task_id)
|
||||
if transcription:
|
||||
return {
|
||||
"status": "completed",
|
||||
"transcription": transcription.decode('utf-8')
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
@@ -1,180 +0,0 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
|
||||
# Synthesize audio
|
||||
with torch.cuda.device(device):
|
||||
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(REF_LANGUAGE),
|
||||
text=target_text,
|
||||
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
return output_wav_path
|
||||
else:
|
||||
return None
|
||||
|
||||
@app.post("/tts")
|
||||
async def synthesize_audio(request: TTSRequest):
|
||||
try:
|
||||
print(f"Received TTS request: {request.dict()}")
|
||||
target_text = request.text
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
|
||||
# Check Redis cache
|
||||
cached_audio = redis_client.get(audio_hash)
|
||||
if cached_audio:
|
||||
audio_info = json.loads(cached_audio)
|
||||
return FileResponse(audio_info['path'], media_type="audio/wav")
|
||||
|
||||
# Check file system
|
||||
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
|
||||
if os.path.exists(file_path):
|
||||
# Cache the file path in Redis
|
||||
redis_client.set(audio_hash, json.dumps({"path": file_path}))
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
|
||||
# Send message to Kafka
|
||||
kafka_producer.send(KAFKA_TOPIC, json.dumps({
|
||||
'text': target_text,
|
||||
'audio_hash': audio_hash
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
|
||||
for _ in range(60): # Wait for up to 30 seconds
|
||||
if os.path.exists(file_path):
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
time.sleep(1)
|
||||
|
||||
# If audio is not generated within the timeout
|
||||
raise HTTPException(status_code=504, detail="Audio generation timed out")
|
||||
except Exception as e:
|
||||
print(f"Error processing TTS request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
# group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='latest',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
|
||||
for message in consumer:
|
||||
target_text = message.value['text']
|
||||
audio_hash = message.value['audio_hash']
|
||||
|
||||
output_path = synthesize(target_text, OUTPUT_PATH)
|
||||
|
||||
if output_path:
|
||||
redis_client.set(audio_hash, json.dumps({"path": output_path}))
|
||||
print(f"Audio synthesized successfully: {output_path}")
|
||||
else:
|
||||
print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6002)
|
||||
@@ -1,180 +0,0 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
|
||||
# Synthesize audio
|
||||
with torch.cuda.device(device):
|
||||
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(REF_LANGUAGE),
|
||||
text=target_text,
|
||||
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
return output_wav_path
|
||||
else:
|
||||
return None
|
||||
|
||||
@app.post("/tts_ko")
|
||||
async def synthesize_audio(request: TTSRequest):
|
||||
try:
|
||||
print(f"Received TTS request: {request.dict()}")
|
||||
target_text = request.text
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
|
||||
# Check Redis cache
|
||||
cached_audio = redis_client.get(audio_hash)
|
||||
if cached_audio:
|
||||
audio_info = json.loads(cached_audio)
|
||||
return FileResponse(audio_info['path'], media_type="audio/wav")
|
||||
|
||||
# Check file system
|
||||
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
|
||||
if os.path.exists(file_path):
|
||||
# Cache the file path in Redis
|
||||
redis_client.set(audio_hash, json.dumps({"path": file_path}))
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
|
||||
# Send message to Kafka
|
||||
kafka_producer.send(KAFKA_TOPIC, json.dumps({
|
||||
'text': target_text,
|
||||
'audio_hash': audio_hash
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
|
||||
for _ in range(60): # Wait for up to 30 seconds
|
||||
if os.path.exists(file_path):
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
time.sleep(1)
|
||||
|
||||
# If audio is not generated within the timeout
|
||||
raise HTTPException(status_code=504, detail="Audio generation timed out")
|
||||
except Exception as e:
|
||||
print(f"Error processing TTS request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
# group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='latest',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
|
||||
for message in consumer:
|
||||
target_text = message.value['text']
|
||||
audio_hash = message.value['audio_hash']
|
||||
|
||||
output_path = synthesize(target_text, OUTPUT_PATH)
|
||||
|
||||
if output_path:
|
||||
redis_client.set(audio_hash, json.dumps({"path": output_path}))
|
||||
print(f"Audio synthesized successfully: {output_path}")
|
||||
else:
|
||||
print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -1,176 +0,0 @@
|
||||
# 导入所需的库
|
||||
import os
|
||||
import soundfile as sf
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaConsumer
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
|
||||
"""
|
||||
整体设计说明:
|
||||
这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务,
|
||||
使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。
|
||||
主要功能包括:
|
||||
1. 初始化配置和模型
|
||||
2. 提供语音合成功能
|
||||
3. 监听Kafka消息并处理TTS任务
|
||||
4. 将合成结果存储到Redis并更新任务状态
|
||||
"""
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 设置GPU设备(如果可用)
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 从环境变量中读取Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 从环境变量中读取Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
|
||||
# 从环境变量中读取TTS相关配置
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# 初始化Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TTS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# 初始化国际化工具
|
||||
i18n = I18nAuto()
|
||||
|
||||
def get_audio_hash(text):
|
||||
"""
|
||||
生成文本的MD5哈希值,用作音频文件名的一部分
|
||||
|
||||
参数:
|
||||
text (str): 需要生成哈希的文本
|
||||
|
||||
返回:
|
||||
str: 文本的MD5哈希值
|
||||
"""
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# 初始化模型
|
||||
print("正在初始化模型...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# 读取参考文本
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("模型初始化成功。")
|
||||
|
||||
def synthesize(target_text, output_wav_path):
|
||||
"""
|
||||
使用GPT-SoVITS模型合成语音
|
||||
|
||||
参数:
|
||||
target_text (str): 需要合成语音的目标文本
|
||||
output_wav_path (str): 输出音频文件的路径
|
||||
|
||||
返回:
|
||||
str: 如果成功,返回输出音频文件的路径;如果失败,返回None
|
||||
"""
|
||||
with torch.cuda.device(device):
|
||||
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(REF_LANGUAGE),
|
||||
text=target_text,
|
||||
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
return output_wav_path
|
||||
else:
|
||||
return None
|
||||
|
||||
def kafka_consumer():
|
||||
"""
|
||||
Kafka消费者函数,用于接收和处理TTS任务
|
||||
|
||||
该函数会持续监听Kafka的TTS主题,接收任务并进行处理:
|
||||
1. 接收任务信息
|
||||
2. 更新任务状态
|
||||
3. 调用synthesize函数合成语音
|
||||
4. 将结果保存到Redis
|
||||
5. 更新任务完成状态
|
||||
"""
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TTS_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='latest',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
print(f"TTS消费者已启动")
|
||||
for message in consumer:
|
||||
try:
|
||||
task_id = message.value['task_id']
|
||||
target_text = message.value['text']
|
||||
text_hash = message.value['text_hash']
|
||||
|
||||
# 更新任务状态为 "processing"
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "processing")
|
||||
|
||||
output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav")
|
||||
|
||||
# 再次检查文件是否存在(以防在此期间被其他进程创建)
|
||||
if not os.path.exists(output_wav_path):
|
||||
output_path = synthesize(target_text, output_wav_path)
|
||||
else:
|
||||
output_path = output_wav_path
|
||||
|
||||
if output_path:
|
||||
# 将结果保存在 DB 2
|
||||
redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path}))
|
||||
print(f"音频合成成功: {output_path}")
|
||||
|
||||
# 更新任务状态为 "completed"
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "completed")
|
||||
else:
|
||||
print("音频合成失败")
|
||||
|
||||
# 更新任务状态为 "failed"
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "failed")
|
||||
except Exception as e:
|
||||
print(f"处理消息时出错: {str(e)}")
|
||||
|
||||
# 更新任务状态为 "failed"
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "failed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置CUDA设备
|
||||
torch.cuda.set_device(device)
|
||||
# 启动Kafka消费者
|
||||
kafka_consumer()
|
||||
@@ -1,68 +0,0 @@
|
||||
import os
|
||||
import whisper
|
||||
import argparse
|
||||
|
||||
def transcribe_audio(model, audio_path):
|
||||
"""
|
||||
使用Whisper模型转录音频文件
|
||||
|
||||
:param model: 加载的Whisper模型
|
||||
:param audio_path: 音频文件路径
|
||||
:return: 转录的文本
|
||||
"""
|
||||
try:
|
||||
result = model.transcribe(audio_path)
|
||||
return result["text"]
|
||||
except Exception as e:
|
||||
print(f"转录失败 {audio_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def process_directory(directory, model):
|
||||
"""
|
||||
处理目录中的所有WAV文件
|
||||
|
||||
:param directory: 包含WAV文件的目录路径
|
||||
:param model: 加载的Whisper模型
|
||||
"""
|
||||
for filename in os.listdir(directory):
|
||||
if filename.lower().endswith('.wav'):
|
||||
input_file = os.path.join(directory, filename)
|
||||
output_file = os.path.splitext(input_file)[0] + ".txt"
|
||||
|
||||
print(f"正在处理: {input_file}")
|
||||
transcription = transcribe_audio(model, input_file)
|
||||
|
||||
if transcription:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(transcription)
|
||||
print(f"转录完成: {output_file}")
|
||||
else:
|
||||
print(f"转录失败: {input_file}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
|
||||
parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
|
||||
parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"], help="Whisper模型大小")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"正在加载Whisper模型 ({args.model})...")
|
||||
model = whisper.load_model(args.model)
|
||||
print("模型加载完成")
|
||||
|
||||
if os.path.isfile(args.input_path):
|
||||
if not args.input_path.lower().endswith('.wav'):
|
||||
print("错误: 输入文件不是WAV格式")
|
||||
return
|
||||
output_file = os.path.splitext(args.input_path)[0] + ".txt"
|
||||
transcription = transcribe_audio(model, args.input_path)
|
||||
if transcription:
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(transcription)
|
||||
print(f"转录完成: {output_file}")
|
||||
elif os.path.isdir(args.input_path):
|
||||
process_directory(args.input_path, model)
|
||||
else:
|
||||
print("错误: 输入路径既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1 +0,0 @@
|
||||
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
@@ -1,186 +0,0 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import whisper
|
||||
import tempfile
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
import json
|
||||
import threading
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 添加密码
|
||||
)
|
||||
|
||||
# Kafka生产者
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 存储WebSocket连接的字典
|
||||
active_connections = {}
|
||||
|
||||
@app.websocket("/asr/ws/{client_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||||
await websocket.accept()
|
||||
active_connections[client_id] = websocket
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 设置接收超时
|
||||
data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
else:
|
||||
await websocket.send_text(f"收到消息: {data}")
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
# 发送心跳
|
||||
await websocket.send_text("heartbeat")
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端 {client_id} 断开连接")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端 {client_id} 断开连接")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
finally:
|
||||
if client_id in active_connections:
|
||||
del active_connections[client_id]
|
||||
|
||||
|
||||
@app.post("/asr")
|
||||
async def transcribe(audio: UploadFile = File(...)):
|
||||
if not audio:
|
||||
raise HTTPException(status_code=400, detail="未提供音频文件")
|
||||
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = f"asr:{audio.filename}:{client_id}"
|
||||
|
||||
# 检查缓存
|
||||
cached_result = redis_client.get(cache_key)
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {cache_key}")
|
||||
return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
temp_audio.flush()
|
||||
|
||||
task = {
|
||||
'file_path': temp_audio.name,
|
||||
'client_id': client_id,
|
||||
'cache_key': cache_key
|
||||
}
|
||||
producer.send(KAFKA_TOPIC, value=task)
|
||||
producer.flush()
|
||||
|
||||
logger.info(f"发送任务到Kafka: {task}")
|
||||
return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
|
||||
|
||||
async def send_transcription(client_id: str, transcription: str):
|
||||
if client_id in active_connections:
|
||||
websocket = active_connections[client_id]
|
||||
await websocket.send_json({"transcription": transcription})
|
||||
else:
|
||||
logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
|
||||
|
||||
def kafka_consumer(consumer_id):
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
group_id='asr_group',
|
||||
max_poll_interval_ms=300000
|
||||
)
|
||||
|
||||
for message in consumer:
|
||||
try:
|
||||
task = message.value
|
||||
file_path = task.get('file_path')
|
||||
client_id = task.get('client_id')
|
||||
cache_key = task.get('cache_key')
|
||||
|
||||
if not file_path or not client_id or not cache_key:
|
||||
logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
|
||||
consumer.commit()
|
||||
continue
|
||||
|
||||
result = model.transcribe(file_path)
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
|
||||
logger.info(f"转录结果: {result['text']}")
|
||||
|
||||
# 将结果存入Redis缓存
|
||||
redis_client.setex(cache_key, 3600, result['text']) # 缓存1小时
|
||||
|
||||
asyncio.run(send_transcription(client_id, result['text']))
|
||||
|
||||
os.remove(file_path)
|
||||
consumer.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
|
||||
|
||||
def start_consumers(num_consumers=1):
|
||||
for i in range(num_consumers):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(i,))
|
||||
consumer_thread.start()
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_consumers()
|
||||
uvicorn.run(app, host="0.0.0.0", port=6000)
|
||||
+30
-2
@@ -23,6 +23,10 @@ REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
OLLAMA_URL = os.getenv('OLLAMA_URL')
|
||||
OLLAMA_URLS = os.getenv('OLLAMA_URLS', OLLAMA_URL).split(',') # 兼容旧配置
|
||||
OLLAMA_TIMEOUT = int(os.getenv('OLLAMA_TIMEOUT', 10))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
@@ -51,6 +55,21 @@ def create_kafka_consumer():
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
async def try_ollama_request(url, data):
|
||||
"""尝试向单个 Ollama API 发送请求"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{url}/api/generate",
|
||||
json=data,
|
||||
stream=True,
|
||||
timeout=OLLAMA_TIMEOUT
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
except Exception as e:
|
||||
print(f"API {url} 请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
task_id = chat_request['task_id']
|
||||
@@ -77,8 +96,16 @@ async def process_chat_request(chat_request):
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
# 尝试所有可用的 API 地址
|
||||
response = None
|
||||
for url in OLLAMA_URLS:
|
||||
response = await try_ollama_request(url, data)
|
||||
if response is not None:
|
||||
print(f"使用 API 地址: {url}")
|
||||
break
|
||||
|
||||
if response is None:
|
||||
raise Exception("所有 API 地址均不可用")
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
@@ -105,6 +132,7 @@ async def process_chat_request(chat_request):
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"chat:{task_id}:status", "error")
|
||||
redis_task_client.set(f"chat:{task_id}:error", str(e))
|
||||
|
||||
def kafka_consumer_thread(consumer_id):
|
||||
consumer = create_kafka_consumer()
|
||||
print(f"消费者 {consumer_id} 已启动")
|
||||
|
||||
Regular → Executable
Regular → Executable
Regular → Executable
@@ -0,0 +1 @@
|
||||
{"GPT": {"v1": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v1": "GPT_SoVITS/pretrained_models/s2G488k.pth", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
@@ -1,35 +0,0 @@
|
||||
kafka:
|
||||
bootstrap_servers:
|
||||
- "222.186.136.78:9092"
|
||||
value_serializer: "json"
|
||||
topics:
|
||||
all_frames:
|
||||
name: "pose-input"
|
||||
num_consumers: 3
|
||||
ten_seconds:
|
||||
name: "cpm-input"
|
||||
num_consumers: 1
|
||||
input_topic:
|
||||
name: "raw-data"
|
||||
num_consumers: 3
|
||||
|
||||
redis:
|
||||
host: "222.186.136.78"
|
||||
port: 6379
|
||||
db: 0
|
||||
password: "Obscura@2024"
|
||||
|
||||
minio:
|
||||
endpoint: "api.obscura.work"
|
||||
access_key: "00v3MtLtIAIkR3hkIuYR"
|
||||
secret_key: "XfDeVe5bJjPU21NEYc023gzJVUTJzQqxsWHqIKMf"
|
||||
secure: true
|
||||
|
||||
mongodb:
|
||||
uri: "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name: "minio_mongo"
|
||||
|
||||
model:
|
||||
pose-path: "worker_sys/function/yolov8n-pose.pt"
|
||||
cpm-path: "worker_sys/OpenBMB/MiniCPM-V-2_6"
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
import json
|
||||
import io
|
||||
from PIL import Image
|
||||
import torch
|
||||
from kafka import KafkaConsumer, KafkaProducer
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import threading
|
||||
import re
|
||||
from datetime import datetime
|
||||
import time
|
||||
import base64
|
||||
import numpy as np
|
||||
import cv2
|
||||
import redis
|
||||
from redis import ConnectionPool
|
||||
import yaml
|
||||
|
||||
# 加载配置文件
|
||||
with open('worker_sys/function/config.yaml', 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
class ImageSequenceProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.max_size = 512 # 设置最大尺寸
|
||||
def extract_time_from_key(self, key_name):
|
||||
# 从key_name中提取时间信息
|
||||
time_str = key_name.split('_')[-2] + '_' + key_name.split('_')[-1].split('.')[0]
|
||||
return datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
|
||||
def process_image_sequence(self, image_data, key_names):
|
||||
frames = []
|
||||
image_times = []
|
||||
for i, (img, key_name) in enumerate(zip(image_data, key_names)):
|
||||
try:
|
||||
if isinstance(img, np.ndarray):
|
||||
# 确保图像是 RGB 格式
|
||||
if img.shape[2] == 3:
|
||||
frame = Image.fromarray(img)
|
||||
else:
|
||||
print(f"Unexpected number of channels for image {i}: {img.shape[2]}")
|
||||
continue
|
||||
else:
|
||||
print(f"Unexpected data type for image {i}: {type(img)}")
|
||||
continue
|
||||
frames.append(frame)
|
||||
image_times.append(self.extract_time_from_key(key_name))
|
||||
print(f"Successfully processed frame {i}")
|
||||
except Exception as e:
|
||||
print(f"Error processing frame {i}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not frames:
|
||||
raise ValueError("No valid frames were processed")
|
||||
|
||||
# 修改时间范围格式
|
||||
start_time = min(image_times)
|
||||
end_time = max(image_times)
|
||||
time_range = {
|
||||
'start': start_time.strftime('%Y-%m-%d %H:%M'),
|
||||
'end': end_time.strftime('%Y-%m-%d %H:%M')
|
||||
}
|
||||
|
||||
# # 计算平均时间间隔(以分钟为单位)
|
||||
# if len(image_times) > 1:
|
||||
# sequence_period_seconds = (image_times[-1] - image_times[0]).total_seconds()/ (len(image_times) - 1)
|
||||
# else:
|
||||
# sequence_period_seconds = 0
|
||||
|
||||
total_duration = (end_time - start_time).total_seconds()
|
||||
num_images = len(image_times)
|
||||
if num_images > 1:
|
||||
sequence_period_seconds = total_duration / (num_images - 1)
|
||||
else:
|
||||
sequence_period_seconds = 0
|
||||
|
||||
question = "Analyze these 3 images as if they were frames from a video. Describe the scene in detail in Chinese, including the setting, number of people, their actions, and any changes or movements observed across the frames."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
"time_range": time_range,
|
||||
"sequence_period_seconds": sequence_period_seconds
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
# 环境提取
|
||||
environments = ["办公室", "室内", "室外", "会议室", "办公"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
# 改进的人数提取
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
# 动作和互动提取
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
# 物体和家具提取
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
class ImageSequenceAnalysisSystem:
|
||||
def __init__(self, model_dir):
|
||||
self.image_processor = ImageSequenceProcessor(model_dir)
|
||||
|
||||
def process_image_sequence(self, image_data, cache_keys):
|
||||
print(f"Attempting to process sequence of {len(image_data)} images")
|
||||
start_time = time.time()
|
||||
try:
|
||||
print("Processing new image sequence...")
|
||||
for i, (img, cache_key) in enumerate(zip(image_data, cache_keys)):
|
||||
print(f"Image {i} type: {type(img)}")
|
||||
if isinstance(img, np.ndarray):
|
||||
print(f"Image {i} shape: {img.shape}, dtype: {img.dtype}")
|
||||
else:
|
||||
print(f"Unexpected data type for image {i}")
|
||||
print(f"Image {i} cache_key: {cache_key}")
|
||||
|
||||
result = self.image_processor.process_image_sequence(image_data, cache_keys)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed image sequence for time range: {result['time_range']}")
|
||||
print(f"Average time between frames: {result['sequence_period_seconds']:.2f} minutes")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing image sequence: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
class ImageProcessingNode:
|
||||
def __init__(self, bootstrap_servers, input_topic, model_dir, group_id, redis_pool_db0, redis_pool_db1):
|
||||
self.consumer = KafkaConsumer(
|
||||
input_topic,
|
||||
bootstrap_servers=bootstrap_servers,
|
||||
group_id=group_id,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
auto_offset_reset='earliest'
|
||||
)
|
||||
self.producer = KafkaProducer(
|
||||
bootstrap_servers=bootstrap_servers,
|
||||
value_serializer=lambda x: json.dumps(x, cls=JSONEncoder).encode('utf-8')
|
||||
)
|
||||
self.input_topic = input_topic
|
||||
self.analysis_system = ImageSequenceAnalysisSystem(model_dir)
|
||||
self.group_id = group_id
|
||||
self.redis_pool_db0 = redis_pool_db0
|
||||
self.redis_pool_db1 = redis_pool_db1
|
||||
self.lock = threading.Lock()
|
||||
self.image_buffer = []
|
||||
self.image_info_buffer = []
|
||||
self.cache_key_buffer = [] # 新增:用于存储cache_key
|
||||
|
||||
|
||||
def process_and_produce(self, message_value):
|
||||
cache_key = message_value.get('cache_key')
|
||||
etag = message_value.get('etag')
|
||||
size = message_value.get('size')
|
||||
object_key = message_value.get('object')
|
||||
|
||||
# print(f"Consumer {self.group_id} received image with key_hash: {key_hash}")
|
||||
|
||||
try:
|
||||
with redis.Redis(connection_pool=self.redis_pool_db0) as redis_client_db0:
|
||||
img_str = redis_client_db0.get(cache_key)
|
||||
if img_str is None:
|
||||
print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
|
||||
return
|
||||
|
||||
img_data = base64.b64decode(img_str)
|
||||
nparr = np.frombuffer(img_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
self.image_buffer.append(img)
|
||||
self.image_info_buffer.append({
|
||||
'key_name': cache_key, # 使用 cache_key 作为 key_name
|
||||
'etag': etag,
|
||||
'size': size,
|
||||
'object_key': object_key
|
||||
# 'key_hash': key_hash
|
||||
})
|
||||
self.cache_key_buffer.append(cache_key)
|
||||
|
||||
if len(self.image_buffer) == 3:
|
||||
result = self.analysis_system.process_image_sequence(self.image_buffer, self.cache_key_buffer)
|
||||
|
||||
if result:
|
||||
result_data = {
|
||||
'results': result,
|
||||
'image_sequence': self.image_info_buffer
|
||||
}
|
||||
|
||||
with self.lock:
|
||||
with redis.Redis(connection_pool=self.redis_pool_db1) as redis_client_db1:
|
||||
# 使用第一张图片的 cache_key 作为 sequence_key
|
||||
sequence_key = f"{self.image_info_buffer[0]['key_name']}"
|
||||
redis_client_db1.set(sequence_key, json.dumps(result_data))
|
||||
print(f"Consumer {self.group_id} processed and saved results to Redis db1 with sequence_key {sequence_key}")
|
||||
print(f"Processed image sequence: {[info['key_name'] for info in self.image_info_buffer]}")
|
||||
|
||||
# 清空缓冲区
|
||||
self.image_buffer = []
|
||||
self.image_info_buffer = []
|
||||
self.cache_key_buffer = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
|
||||
while True:
|
||||
messages = self.consumer.poll(timeout_ms=1000)
|
||||
for tp, records in messages.items():
|
||||
for record in records:
|
||||
message_value = record.value
|
||||
self.process_and_produce(message_value)
|
||||
|
||||
# 每分钟打印一次当前缓冲区状态
|
||||
if len(self.image_buffer) < 3:
|
||||
print(f"Still waiting for more images. Current buffer size: {len(self.image_buffer)}")
|
||||
time.sleep(60) # 等待60秒(1分钟)
|
||||
|
||||
def start_consumer(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password):
|
||||
redis_pool_db0 = ConnectionPool(host=redis_host, port=redis_port, db=0, password=redis_password)
|
||||
redis_pool_db1 = ConnectionPool(host=redis_host, port=redis_port, db=2, password=redis_password)
|
||||
|
||||
node = ImageProcessingNode(
|
||||
kafka_bootstrap_servers,
|
||||
input_topic,
|
||||
model_dir,
|
||||
group_id,
|
||||
redis_pool_db0,
|
||||
redis_pool_db1
|
||||
)
|
||||
print(f"Image Processing Node {group_id} initialized.")
|
||||
print(f"Listening on topic: {input_topic}")
|
||||
node.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configuration
|
||||
# kafka_bootstrap_servers = ['222.186.136.78:9092']
|
||||
# input_topic = 'cpm-input'
|
||||
# model_dir = 'worker_sys/OpenBMB/MiniCPM-V-2_6'
|
||||
# redis_host = '222.186.136.78'
|
||||
# redis_port = 6379
|
||||
# redis_password = 'Obscura@2024'
|
||||
# num_consumers = 1
|
||||
|
||||
kafka_bootstrap_servers = config['kafka']['bootstrap_servers']
|
||||
input_topic = config['kafka']['topics']['ten_seconds']['name']
|
||||
model_dir = config['model']['cpm-path']
|
||||
redis_host = config['redis']['host']
|
||||
redis_port = config['redis']['port']
|
||||
redis_password = config['redis']['password']
|
||||
num_consumers = config['kafka']['topics']['ten_seconds']['num_consumers']
|
||||
|
||||
# 创建多个消费者线程
|
||||
consumer_threads = []
|
||||
for i in range(num_consumers):
|
||||
group_id = f'cpm_group_{i}'
|
||||
thread = threading.Thread(
|
||||
target=start_consumer,
|
||||
args=(kafka_bootstrap_servers, input_topic, model_dir, group_id, redis_host, redis_port, redis_password)
|
||||
)
|
||||
consumer_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# 等待所有消费者线程完成
|
||||
for thread in consumer_threads:
|
||||
thread.join()
|
||||
@@ -1,183 +0,0 @@
|
||||
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
from kafka import KafkaConsumer
|
||||
from ultralytics import YOLO
|
||||
import threading
|
||||
import redis
|
||||
import base64
|
||||
import yaml
|
||||
from PIL import Image
|
||||
|
||||
# 加载配置文件
|
||||
with open('worker_sys/function/config.yaml', 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
class YOLOv8nPoseProcessor:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
|
||||
def process_image(self, img):
|
||||
results = self.model(img)
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
result_data = []
|
||||
for result in results:
|
||||
boxes = result.boxes.xywh.tolist()
|
||||
keypoints = result.keypoints.xy.tolist() if hasattr(result, 'keypoints') else None
|
||||
classes = result.boxes.cls.tolist()
|
||||
confs = result.boxes.conf.tolist()
|
||||
|
||||
for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
|
||||
result_data.append({
|
||||
'box': box,
|
||||
'keypoints': keypoints[i] if keypoints else None,
|
||||
'class': int(cls),
|
||||
'class_name': self.model.names[int(cls)],
|
||||
'confidence': float(conf)
|
||||
})
|
||||
|
||||
if not result_data:
|
||||
result_data.append({
|
||||
'box': None,
|
||||
'keypoints': None,
|
||||
'class': None,
|
||||
'class_name': None,
|
||||
'confidence': None,
|
||||
'message': 'No object detected'
|
||||
})
|
||||
|
||||
return json.dumps(result_data)
|
||||
|
||||
class ImageProcessingNode:
|
||||
def __init__(self, bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
|
||||
self.consumer = KafkaConsumer(
|
||||
input_topic,
|
||||
bootstrap_servers=bootstrap_servers,
|
||||
group_id=group_id,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
auto_offset_reset='earliest'
|
||||
)
|
||||
self.input_topic = input_topic
|
||||
self.yolo_processor = YOLOv8nPoseProcessor(model_path)
|
||||
self.group_id = group_id
|
||||
self.redis_client_db0 = redis.Redis(host=redis_host, port=redis_port, db=0, password=redis_password)
|
||||
self.redis_client_db1 = redis.Redis(host=redis_host, port=redis_port, db=1, password=redis_password)
|
||||
|
||||
def process_and_produce(self, message):
|
||||
|
||||
cache_key = message.get('cache_key')
|
||||
object_key = message.get('object')
|
||||
etag = message.get('etag')
|
||||
size = message.get('size')
|
||||
|
||||
print(f"Consumer {self.group_id} processing image with cache_key: {cache_key}")
|
||||
|
||||
|
||||
# 从Redis db0获取图片数据
|
||||
img_str = self.redis_client_db0.get(cache_key)
|
||||
if img_str is None:
|
||||
print(f"Error: Image data not found in Redis db0 for cache_key: {cache_key}")
|
||||
return
|
||||
|
||||
# 将base64编码的图片数据转换为OpenCV格式
|
||||
try:
|
||||
img_data = base64.b64decode(img_str)
|
||||
nparr = np.frombuffer(img_data, np.uint8)
|
||||
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if img is None:
|
||||
print(f"Error: Unable to decode image data for cache_key: {cache_key}")
|
||||
return
|
||||
|
||||
results = self.yolo_processor.process_image(img)
|
||||
formatted_results = self.yolo_processor.format_results(results)
|
||||
|
||||
# 创建包含所有必要信息的字典
|
||||
result_data = {
|
||||
"pose_results": json.loads(formatted_results),
|
||||
'cache_key': cache_key,
|
||||
"object": object_key,
|
||||
"etag": etag,
|
||||
"size": size
|
||||
}
|
||||
|
||||
# 将结果保存到Redis db1
|
||||
try:
|
||||
serialized_data = json.dumps(result_data)
|
||||
if not serialized_data:
|
||||
print(f"Error: Serialized data is empty for cache_key: {cache_key}")
|
||||
return
|
||||
|
||||
self.redis_client_db1.set(cache_key, serialized_data)
|
||||
print(f"Consumer {self.group_id} processed and saved results to Redis db1 with cache_key {cache_key}")
|
||||
except TypeError as e:
|
||||
print(f"Error serializing data: {e}")
|
||||
print(f"Problematic data: {result_data}")
|
||||
except Exception as e:
|
||||
print(f"Unexpected error when saving to Redis: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {e}")
|
||||
|
||||
|
||||
def run(self):
|
||||
print(f"Consumer {self.group_id} starting to consume messages from {self.input_topic}")
|
||||
for message in self.consumer:
|
||||
message_value = message.value
|
||||
cache_key = message_value.get('cache_key')
|
||||
if cache_key:
|
||||
threading.Thread(target=self.process_and_produce, args=(message_value,)).start()
|
||||
else:
|
||||
print(f"Consumer {self.group_id} error: Received message without cache_key")
|
||||
|
||||
def start_consumer(kafka_bootstrap_servers, input_topic, model_path, group_id, redis_host, redis_port, redis_password):
|
||||
node = ImageProcessingNode(
|
||||
kafka_bootstrap_servers,
|
||||
input_topic,
|
||||
model_path,
|
||||
group_id,
|
||||
redis_host,
|
||||
redis_port,
|
||||
redis_password
|
||||
)
|
||||
print(f"Image Processing Node {group_id} initialized.")
|
||||
print(f"Listening on topic: {input_topic}")
|
||||
node.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 创建多个消费者线程
|
||||
# Kafka 配置
|
||||
kafka_bootstrap_servers = config['kafka']['bootstrap_servers']
|
||||
input_topic = config['kafka']['topics']['all_frames']['name']
|
||||
num_consumers = config['kafka']['topics']['all_frames']['num_consumers']
|
||||
# 模型路径
|
||||
model_path = config['model']['pose-path']
|
||||
# Redis 配置
|
||||
redis_host = config['redis']['host']
|
||||
redis_port = config['redis']['port']
|
||||
redis_password = config['redis']['password']
|
||||
|
||||
|
||||
consumer_threads = []
|
||||
for i in range(num_consumers):
|
||||
group_id = f'pose_group_{i}'
|
||||
thread = threading.Thread(
|
||||
target=start_consumer,
|
||||
args=(
|
||||
kafka_bootstrap_servers,
|
||||
input_topic,
|
||||
model_path,
|
||||
group_id,
|
||||
redis_host,
|
||||
redis_port,
|
||||
redis_password
|
||||
)
|
||||
)
|
||||
consumer_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# 等待所有消费者线程完成
|
||||
for thread in consumer_threads:
|
||||
thread.join()
|
||||
@@ -1,204 +0,0 @@
|
||||
import json
|
||||
import yaml
|
||||
from kafka import KafkaConsumer, KafkaProducer
|
||||
import time
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from minio import Minio
|
||||
import io
|
||||
from PIL import Image
|
||||
import base64
|
||||
import traceback
|
||||
|
||||
|
||||
# 加载配置文件
|
||||
with open('worker_sys/function/config.yaml', 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
# Kafka 配置
|
||||
kafka_config = {
|
||||
'bootstrap_servers': config['kafka']['bootstrap_servers'],
|
||||
'value_serializer': lambda v: json.dumps(v).encode('utf-8') if config['kafka']['value_serializer'] == 'json' else None
|
||||
}
|
||||
|
||||
# 创建Kafka生产者
|
||||
producer = KafkaProducer(**kafka_config)
|
||||
|
||||
# 创建Kafka消费者(用于输入主题)
|
||||
consumer = KafkaConsumer(
|
||||
config['kafka']['input_topic']['name'],
|
||||
group_id='image-processor',
|
||||
bootstrap_servers=config['kafka']['bootstrap_servers']
|
||||
)
|
||||
|
||||
# Redis 配置
|
||||
redis_config = config['redis']
|
||||
|
||||
# 创建 Redis 客户端
|
||||
redis_client = redis.Redis(**redis_config)
|
||||
|
||||
# MinIO 配置
|
||||
minio_config = config['minio']
|
||||
|
||||
# 创建 MinIO 客户端
|
||||
minio_client = Minio(
|
||||
minio_config['endpoint'],
|
||||
access_key=minio_config['access_key'],
|
||||
secret_key=minio_config['secret_key'],
|
||||
secure=minio_config['secure']
|
||||
)
|
||||
|
||||
# Kafka topics
|
||||
topic_all_frames = config['kafka']['topics']['all_frames']['name']
|
||||
topic_ten_seconds = config['kafka']['topics']['ten_seconds']['name']
|
||||
topic_input = config['kafka']['input_topic']['name']
|
||||
|
||||
# 消费者数量
|
||||
NUM_CONSUMERS_ALL_FRAMES = config['kafka']['topics']['all_frames']['num_consumers']
|
||||
NUM_CONSUMERS_TEN_SECONDS = config['kafka']['topics']['ten_seconds']['num_consumers']
|
||||
NUM_CONSUMERS_INPUT = config['kafka']['input_topic']['num_consumers']
|
||||
|
||||
def parse_key_name(key_name):
|
||||
parts = key_name.split('/')
|
||||
bucket = parts[0]
|
||||
image_name = '/'.join(parts[1:])
|
||||
image_parts = image_name.rsplit('.', 1)[0].split('_')
|
||||
camera_name = image_parts[0]
|
||||
timestamp = '_'.join(image_parts[1:3])
|
||||
return bucket, camera_name, timestamp
|
||||
|
||||
def should_send_to_ten_seconds_topic(timestamp):
|
||||
try:
|
||||
dt = datetime.strptime(timestamp, '%Y%m%d_%H%M%S')
|
||||
return dt.second % 10 == 0
|
||||
except ValueError as e:
|
||||
print(f"Error parsing timestamp '{timestamp}': {str(e)}")
|
||||
return False
|
||||
|
||||
def get_image_from_minio_and_cache(key_name):
|
||||
bucket, object_name = key_name.split('/', 1)
|
||||
cache_key = f"{key_name}"
|
||||
|
||||
try:
|
||||
response = minio_client.get_object(bucket, object_name)
|
||||
image_data = response.read()
|
||||
print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}")
|
||||
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
redis_client.setex(cache_key, 86400, img_str)
|
||||
print(f"Successfully cached image in Redis: {cache_key}")
|
||||
|
||||
return cache_key
|
||||
except Exception as e:
|
||||
print(f"Error in get_image_from_minio_and_cache: {str(e)}")
|
||||
print("Traceback:")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def process_message(message):
|
||||
|
||||
value_data = json.loads(message.value.decode('utf-8'))
|
||||
key_name = value_data['Key']
|
||||
print(f"Processing key: {key_name}")
|
||||
|
||||
# 从 key_name 中提取 bucket
|
||||
parts = key_name.split('/', 1)
|
||||
if len(parts) != 2:
|
||||
print(f"Error: Invalid key format. Expected 'bucket/object', got '{key_name}'")
|
||||
return
|
||||
|
||||
bucket, object_name = parts
|
||||
|
||||
# 解析 camera_name 和 timestamp
|
||||
camera_name, timestamp = parse_object_name(object_name)
|
||||
|
||||
cache_key = get_image_from_minio_and_cache(key_name)
|
||||
|
||||
if cache_key is None:
|
||||
print(f"Failed to process image: {key_name}")
|
||||
return
|
||||
# 提取 etag 和 size 信息
|
||||
object_info = value_data['Records'][0]['s3']['object']
|
||||
etag = object_info.get('eTag', '')
|
||||
size = object_info.get('size', 0)
|
||||
|
||||
|
||||
message_data = {
|
||||
'bucket': bucket,
|
||||
'camera_name': camera_name,
|
||||
'timestamp': timestamp,
|
||||
'object': object_name,
|
||||
'cache_key': cache_key,
|
||||
'etag': etag,
|
||||
'size': size
|
||||
}
|
||||
|
||||
# 发送到 pose-input 主题
|
||||
producer.send(topic_all_frames, value=message_data)
|
||||
print(f"Sent message to {topic_all_frames}: {key_name}")
|
||||
|
||||
# 发送到 cpm-input 主题
|
||||
producer.send(topic_ten_seconds, value=message_data)
|
||||
print(f"Sent message to {topic_ten_seconds}: {key_name}")
|
||||
|
||||
# #只有在满足特定条件时才发送到 cpm-input 主题
|
||||
# if should_send_to_ten_seconds_topic(timestamp):
|
||||
# producer.send(TOPIC_TEN_SECONDS, value=message_data)
|
||||
# print(f"Sent message to {TOPIC_TEN_SECONDS}: {key_name}")
|
||||
|
||||
producer.flush()
|
||||
print(f"Successfully processed and sent messages for: {key_name}")
|
||||
|
||||
|
||||
def parse_object_name(object_name):
|
||||
# 假设对象名格式为 "cameraX_YYYYMMDD_HHMMSS.jpg"
|
||||
parts = object_name.split('_')
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid object name format: {object_name}")
|
||||
|
||||
camera_name = parts[0]
|
||||
timestamp = f"{parts[1]}_{parts[2].split('.')[0]}"
|
||||
return camera_name, timestamp
|
||||
|
||||
def get_image_from_minio_and_cache(key_name):
|
||||
print(f"Received key_name: {key_name}")
|
||||
|
||||
try:
|
||||
bucket, object_name = key_name.split('/', 1)
|
||||
except ValueError as e:
|
||||
print(f"Error splitting key_name '{key_name}': {str(e)}")
|
||||
return None
|
||||
|
||||
cache_key = f"{key_name}"
|
||||
|
||||
try:
|
||||
response = minio_client.get_object(bucket, object_name)
|
||||
image_data = response.read()
|
||||
print(f"Successfully retrieved image from MinIO: {bucket}/{object_name}")
|
||||
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||
|
||||
redis_client.setex(cache_key, 86400, img_str)
|
||||
print(f"Successfully cached image in Redis: {cache_key}")
|
||||
|
||||
return cache_key
|
||||
except Exception as e:
|
||||
print(f"Error in get_image_from_minio_and_cache: {str(e)}")
|
||||
print("Traceback:")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def main():
|
||||
print("Starting image processing service...")
|
||||
for message in consumer:
|
||||
process_message(message)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,91 +0,0 @@
|
||||
import redis
|
||||
import pymongo
|
||||
import time
|
||||
import json
|
||||
import yaml
|
||||
|
||||
def sync_data(redis_client, mongo_db, redis_db, mongo_collection_name):
|
||||
mongo_collection = mongo_db[mongo_collection_name]
|
||||
all_keys = redis_client.keys('*')
|
||||
synced_count = 0
|
||||
|
||||
for key in all_keys:
|
||||
try:
|
||||
poem_data = redis_client.get(key)
|
||||
|
||||
if poem_data:
|
||||
poem_dict = json.loads(poem_data)
|
||||
|
||||
# 检查MongoDB中是否已存在相同的文档
|
||||
existing_doc = mongo_collection.find_one({'_id': key.decode('utf-8')})
|
||||
|
||||
if not existing_doc or existing_doc != poem_dict:
|
||||
mongo_collection.update_one(
|
||||
{'_id': key.decode('utf-8')},
|
||||
{'$set': poem_dict},
|
||||
upsert=True
|
||||
)
|
||||
print(f"Synced data with key: {key} from Redis DB {redis_db} to MongoDB collection '{mongo_collection_name}'")
|
||||
synced_count += 1
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error decoding JSON for key: {key} in Redis DB {redis_db}")
|
||||
except Exception as e:
|
||||
print(f"Error syncing data for key {key} in Redis DB {redis_db}: {str(e)}")
|
||||
|
||||
return synced_count
|
||||
|
||||
def main(config):
|
||||
# Redis配置
|
||||
redis_host = config['redis']['host']
|
||||
redis_port = config['redis']['port']
|
||||
redis_password = config['redis']['password']
|
||||
|
||||
# MongoDB配置
|
||||
mongo_uri = config['mongodb']['uri']
|
||||
mongo_db_name = config['mongodb']['db_name']
|
||||
|
||||
# 连接到MongoDB
|
||||
mongo_client = pymongo.MongoClient(mongo_uri)
|
||||
mongo_db = mongo_client[mongo_db_name]
|
||||
|
||||
# 固定的Redis数据库和MongoDB集合映射
|
||||
db_collection_map = {
|
||||
0: 'pose-result-db0',
|
||||
1: 'pose-result-db1',
|
||||
2: 'cpm-result-db2'
|
||||
}
|
||||
|
||||
print("Selected databases and collections for syncing:")
|
||||
for db, collection in db_collection_map.items():
|
||||
print(f" Redis DB {db} -> MongoDB collection '{collection}'")
|
||||
|
||||
while True:
|
||||
print("Starting sync...")
|
||||
total_synced = 0
|
||||
for db, collection in db_collection_map.items():
|
||||
print(f"Syncing Redis DB {db} to MongoDB collection '{collection}'...")
|
||||
try:
|
||||
redis_client = redis.Redis(host=redis_host, port=redis_port, db=db, password=redis_password)
|
||||
synced_count = sync_data(redis_client, mongo_db, db, collection)
|
||||
total_synced += synced_count
|
||||
except redis.exceptions.AuthenticationError:
|
||||
print(f"Error: Authentication failed for Redis DB {db}. Skipping...")
|
||||
except redis.exceptions.ConnectionError:
|
||||
print(f"Error: Unable to connect to Redis DB {db}. Skipping...")
|
||||
except Exception as e:
|
||||
print(f"Error occurred while syncing Redis DB {db}: {str(e)}. Skipping...")
|
||||
|
||||
if total_synced > 0:
|
||||
print(f"Sync completed. {total_synced} documents synced. Waiting for next update...")
|
||||
else:
|
||||
print("No new data to sync. Waiting for next update...")
|
||||
|
||||
time.sleep(300) # 等待5分钟后再次同步
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 加载配置文件
|
||||
with open('worker_sys/function/config.yaml', 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
# 运行主程序
|
||||
main(config)
|
||||
@@ -1,299 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import json
|
||||
import re
|
||||
from pymongo import MongoClient
|
||||
import time
|
||||
from bson import ObjectId
|
||||
import os
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 数据库连接模块
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def save_result(self, result):
|
||||
# 如果 result 中没有 filename,使用时间戳作为替代
|
||||
# filename = result.get('filename', f"unknown_{result['timestamp']}")
|
||||
filename = result.get('filename')
|
||||
# 检查是否已存在相同 filename 的结果
|
||||
existing_result = self.results_collection.find_one({'filename': filename})
|
||||
if existing_result:
|
||||
print(f"Video with filename {filename} has already been processed. Skipping.")
|
||||
return
|
||||
|
||||
# 确保 result 中有 filename
|
||||
result['filename'] = filename
|
||||
|
||||
# 将 ObjectId 转换为字符串
|
||||
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
|
||||
result['video_id'] = str(result['video_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
|
||||
def is_sequence_processed(self, filename):
|
||||
return self.results_collection.find_one({'filename': filename}) is not None
|
||||
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
# 视频处理模块
|
||||
class ImageSequenceProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.max_size = 512 # 设置最大尺寸
|
||||
|
||||
def compress_image(self, image):
|
||||
# 保持纵横比的情况下调整图片大小
|
||||
image.thumbnail((self.max_size, self.max_size))
|
||||
|
||||
# 如果图像已经是JPEG格式,直接返回调整大小后的图像
|
||||
if image.format == 'JPEG':
|
||||
return image
|
||||
|
||||
# 对于非JPEG格式,进行压缩
|
||||
buffer = io.BytesIO()
|
||||
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
|
||||
# 保持透明度
|
||||
image.save(buffer, format="PNG", optimize=True)
|
||||
else:
|
||||
# 转换为JPEG并压缩
|
||||
image.convert('RGB').save(buffer, format="JPEG", quality=85, optimize=True)
|
||||
|
||||
buffer.seek(0)
|
||||
return Image.open(buffer)
|
||||
|
||||
def process_image_sequence(self, image_paths):
|
||||
frames = [self.compress_image(Image.open(img_path)) for img_path in image_paths]
|
||||
question = "Analyze these 10 images as if they were frames from a video. Describe the scene in detail in Chinese, including the setting, number of people, their actions, and any changes or movements observed across the frames."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
# 环境提取
|
||||
environments = ["办公室", "室内", "室外", "会议室", "办公"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
# 改进的人数提取
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
# 动作和互动提取
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
# 物体和家具提取
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 主处理类
|
||||
class ImageSequenceAnalysisSystem:
|
||||
def __init__(self, mongo_uri, db_name, model_dir, results_collection_name):
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.image_processor = ImageSequenceProcessor(model_dir)
|
||||
self.last_processed_time = datetime.now() - timedelta(hours=1)
|
||||
|
||||
def get_all_images(self, image_folders):
|
||||
image_files = []
|
||||
for folder in image_folders:
|
||||
image_files.extend(glob.glob(os.path.join(folder, '*.jpg')))
|
||||
image_files.sort()
|
||||
return image_files
|
||||
|
||||
|
||||
def process_image_sequence(self, image_paths):
|
||||
print(f"Attempting to process sequence: {[os.path.basename(img) for img in image_paths]}")
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 使用第一张图片的文件名作为序列的标识符
|
||||
filename = os.path.basename(image_paths[0])
|
||||
|
||||
if self.db_handler.is_sequence_processed(filename):
|
||||
print(f"Skipping already processed image sequence: {filename}")
|
||||
return False
|
||||
|
||||
print("Processing new image sequence...")
|
||||
result = self.image_processor.process_image_sequence(image_paths)
|
||||
|
||||
# timestamp = datetime.now()
|
||||
# result['timestamp'] = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
result['image_paths'] = image_paths
|
||||
result['filename'] = filename
|
||||
|
||||
# 计算图片序列的周期
|
||||
image_times = [self.get_file_time(img) for img in image_paths]
|
||||
if len(image_times) >= 2:
|
||||
time_diff = (image_times[-1] - image_times[0]).total_seconds()
|
||||
period_minutes = time_diff / (len(image_times) - 1) / 60
|
||||
result['sequence_period_minutes'] = round(period_minutes, 2)
|
||||
|
||||
# 添加时间段信息
|
||||
result['time_range'] = {
|
||||
'start': image_times[0].strftime("%Y-%m-%d %H:%M"),
|
||||
'end': image_times[-1].strftime("%Y-%m-%d %H:%M")
|
||||
}
|
||||
|
||||
save_result = self.db_handler.save_result(result)
|
||||
print(f"Result saved to: {self.db_handler.results_collection.name}")
|
||||
print(f"Result filename: {filename}")
|
||||
return save_result
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing image sequence: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
# @staticmethod
|
||||
# def extract_time_from_filename(filename):
|
||||
# # 假设文件名格式为 "YYYYMMDDHHMMSS.jpg"
|
||||
# time_str = filename.split('.')[0]
|
||||
# return datetime.strptime(time_str, "%Y%m%d%H%M")
|
||||
|
||||
@staticmethod
|
||||
def get_file_time(file_path):
|
||||
# 获取文件的修改时间
|
||||
mod_time = os.path.getmtime(file_path)
|
||||
return datetime.fromtimestamp(mod_time)
|
||||
|
||||
def process_all_unprocessed_images(self, image_folders):
|
||||
print(f"Searching for unprocessed images in: {image_folders}")
|
||||
all_images = self.get_all_images(image_folders)
|
||||
print(f"Found {len(all_images)} images in total")
|
||||
selected_images = all_images[::10]
|
||||
# Group images into sequences of 3
|
||||
image_sequences = [selected_images[i:i+10] for i in range(0, len(selected_images), 10)]
|
||||
# image_sequences = [all_images[i:i+10] for i in range(0, len(all_images), 10)]
|
||||
|
||||
processed_sequences = 0
|
||||
for sequence in image_sequences:
|
||||
if len(sequence) == 10:
|
||||
self.process_image_sequence(sequence)
|
||||
processed_sequences += 1
|
||||
else:
|
||||
print(f"Warning: Incomplete sequence. Found {len(sequence)} images.")
|
||||
|
||||
if processed_sequences == 0:
|
||||
print("All current photos have been processed. Waiting for new photos...")
|
||||
else:
|
||||
print(f"Processed {processed_sequences} sequences.")
|
||||
|
||||
def run(self, root_folders):
|
||||
print(f"Starting the system with root folder: {', '.join(root_folders)}")
|
||||
|
||||
while True:
|
||||
current_time = datetime.now()
|
||||
time_since_last_process = (current_time - self.last_processed_time).total_seconds()
|
||||
|
||||
if time_since_last_process >= 3600: # 1小时 = 3600秒
|
||||
self.process_all_unprocessed_images(root_folders)
|
||||
self.last_processed_time = current_time
|
||||
|
||||
# 计算下次检查的等待时间
|
||||
wait_time = max(0, 3600 - (datetime.now() - self.last_processed_time).total_seconds())
|
||||
print(f"Waiting for new photos... Next check in {wait_time:.0f} seconds.")
|
||||
|
||||
time.sleep(60) # 每分钟检查一次是否需要处理
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "cpm"
|
||||
|
||||
model_dir = "worker_sys/OpenBMB/MiniCPM-V-2_6"
|
||||
|
||||
root_folders = [
|
||||
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" ,
|
||||
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics"
|
||||
] # 修改为 cam1 文件夹的路径
|
||||
|
||||
system = ImageSequenceAnalysisSystem(mongo_uri, db_name, model_dir, results_collection_name)
|
||||
system.run(root_folders)
|
||||
@@ -1,194 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
import json
|
||||
from pymongo import MongoClient
|
||||
import os
|
||||
import glob
|
||||
from datetime import datetime, timedelta
|
||||
from ultralytics import YOLO
|
||||
from bson import ObjectId
|
||||
import time
|
||||
|
||||
# 数据库连接模块
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def save_result(self, result):
|
||||
filename = result.get('filename', f"unknown_{result['timestamp']}")
|
||||
|
||||
existing_result = self.results_collection.find_one({'filename': filename})
|
||||
if existing_result:
|
||||
print(f"Image with filename {filename} has already been processed. Skipping.")
|
||||
return False
|
||||
|
||||
result['filename'] = filename
|
||||
|
||||
if 'image_id' in result and isinstance(result['image_id'], ObjectId):
|
||||
result['image_id'] = str(result['image_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
return True
|
||||
|
||||
def is_image_processed(self, filename):
|
||||
return self.results_collection.find_one({'filename': filename}) is not None
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
# YOLOv8nPoseProcessor 类
|
||||
class YOLOv8nPoseProcessor:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
|
||||
def process_image(self, img):
|
||||
results = self.model(img)
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
result_data = []
|
||||
for result in results:
|
||||
boxes = result.boxes.xywh.tolist()
|
||||
keypoints = result.keypoints.xy.tolist() if hasattr(result, 'keypoints') else None
|
||||
classes = result.boxes.cls.tolist()
|
||||
confs = result.boxes.conf.tolist()
|
||||
|
||||
for i, (box, cls, conf) in enumerate(zip(boxes, classes, confs)):
|
||||
result_data.append({
|
||||
'box': box,
|
||||
'keypoints': keypoints[i] if keypoints else None,
|
||||
'class': int(cls),
|
||||
'class_name': self.model.names[int(cls)],
|
||||
'confidence': float(conf)
|
||||
})
|
||||
|
||||
if not result_data:
|
||||
result_data.append({
|
||||
'box': None,
|
||||
'keypoints': None,
|
||||
'class': None,
|
||||
'class_name': None,
|
||||
'confidence': None,
|
||||
'message': 'No object detected'
|
||||
})
|
||||
|
||||
return json.dumps(result_data)
|
||||
|
||||
# 主处理类
|
||||
class ImageAnalysisSystem:
|
||||
def __init__(self, mongo_uri, db_name, model_path, results_collection_name):
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.image_processor = YOLOv8nPoseProcessor(model_path)
|
||||
self.last_processed_time = datetime.now() - timedelta(hours=1)
|
||||
|
||||
def get_all_images(self, image_folders):
|
||||
image_files = []
|
||||
for folder in image_folders:
|
||||
image_files.extend(glob.glob(os.path.join(folder, '*.jpg')))
|
||||
image_files.sort()
|
||||
return image_files
|
||||
@staticmethod
|
||||
def get_file_time(file_path):
|
||||
# 获取文件的修改时间
|
||||
mod_time = os.path.getmtime(file_path)
|
||||
return datetime.fromtimestamp(mod_time)
|
||||
def process_image(self, image_path):
|
||||
print(f"Attempting to process image: {os.path.basename(image_path)}")
|
||||
try:
|
||||
# json_folder = os.path.join("/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2", 'json')
|
||||
# json_filename = f"{os.path.basename(image_path).split('.')[0]}.json"
|
||||
# json_path = os.path.join(json_folder, json_filename)
|
||||
|
||||
# if os.path.exists(json_path):
|
||||
# print(f"Skipping already processed image: {json_path}")
|
||||
# return
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
if self.db_handler.is_image_processed(filename):
|
||||
print(f"Skipping already processed image: {filename}")
|
||||
return False
|
||||
|
||||
print("Processing new image...")
|
||||
|
||||
image = Image.open(image_path)
|
||||
results = self.image_processor.process_image(image)
|
||||
formatted_results = self.image_processor.format_results(results)
|
||||
|
||||
# timestamp = datetime.now()
|
||||
file_timestamp = self.get_file_time(image_path)
|
||||
result = {
|
||||
'timestamp': file_timestamp.strftime("%Y%m%d_%H%M%S"),
|
||||
'image_path': image_path,
|
||||
'filename': os.path.basename(image_path),
|
||||
'results': json.loads(formatted_results)
|
||||
}
|
||||
|
||||
# os.makedirs(json_folder, exist_ok=True)
|
||||
# with open(json_path, 'w', encoding='utf-8') as f:
|
||||
# json.dump(result, f, ensure_ascii=False, indent=4, cls=JSONEncoder)
|
||||
|
||||
# self.db_handler.save_result(result)
|
||||
# print(f"Processed image at: {timestamp}")
|
||||
# print(f"JSON saved to: {json_path}")
|
||||
# print(f"result saved to: {results_collection_name}")
|
||||
if self.db_handler.save_result(result):
|
||||
print(f"Result saved to: {self.db_handler.results_collection.name}")
|
||||
return True
|
||||
else:
|
||||
print(f"Image {filename} was already in the database. Skipping.")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def process_all_unprocessed_images(self, image_folders):
|
||||
print(f"Searching for unprocessed images in: {image_folders}")
|
||||
all_images = self.get_all_images(image_folders)
|
||||
print(f"Found {len(all_images)} images in total")
|
||||
|
||||
processed_count = 0
|
||||
for image_path in all_images:
|
||||
if self.process_image(image_path):
|
||||
processed_count += 1
|
||||
|
||||
return processed_count
|
||||
|
||||
def run(self, root_folder):
|
||||
print(f"Starting the system with root folder: {', '.join(root_folder)}")
|
||||
# image_folder = os.path.join(root_folder)
|
||||
|
||||
while True:
|
||||
print("Checking for unprocessed images...")
|
||||
processed_count = self.process_all_unprocessed_images(root_folder)
|
||||
|
||||
if processed_count > 0:
|
||||
print(f"Finished processing {processed_count} images.")
|
||||
else:
|
||||
print("No new images to process. Waiting for new images...")
|
||||
|
||||
# 等待一段时间后再次检查新图片
|
||||
time.sleep(60) # 每分钟检查一次是否有新图片
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "pose"
|
||||
|
||||
model_path = "worker_sys/function/yolov8x-pose.pt" # 请确保这个路径指向你的YOLO-Pose模型文件
|
||||
|
||||
root_folder = [
|
||||
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam2/CapturePics" ,
|
||||
"/www/wwwroot/zj.obscura.ac.cn/ipcam/Office/Cam1/CapturePics"
|
||||
] # 修改为 cam1 文件夹的路径
|
||||
|
||||
system = ImageAnalysisSystem(mongo_uri, db_name, model_path, results_collection_name)
|
||||
system.run(root_folder)
|
||||
@@ -1,426 +0,0 @@
|
||||
# main.py
|
||||
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import string
|
||||
from decord import VideoReader
|
||||
from PIL import Image
|
||||
from fastapi.responses import FileResponse
|
||||
import logging
|
||||
from config import *
|
||||
|
||||
app = FastAPI()
|
||||
v1_app = FastAPI()
|
||||
app.mount("/v1", v1_app)
|
||||
|
||||
|
||||
# CORS设置
|
||||
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = MAIN_REDIS_DB
|
||||
REDIS_API_DB = REDIS_API_DB
|
||||
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 定义支持的任务类型
|
||||
KAFKA_TOPICS = {
|
||||
'pose': 'pose',
|
||||
'mediapipe': 'mediapipe',
|
||||
'qwenvl': 'qwenvl',
|
||||
'yolo': 'yolo',
|
||||
'fall': 'fall',
|
||||
'face': 'face',
|
||||
'cpm': 'cpm',
|
||||
'compare': 'compare'
|
||||
}
|
||||
|
||||
TASK_TYPES = list(KAFKA_TOPICS.keys())
|
||||
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
|
||||
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
|
||||
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
|
||||
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
|
||||
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
|
||||
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
|
||||
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
|
||||
redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db'])
|
||||
@v1_app.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon():
|
||||
file_name = "favicon.ico"
|
||||
file_path = os.path.join(app.root_path, "static", file_name)
|
||||
if os.path.isfile(file_path):
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
return {"message": "Favicon not found"}, 404
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
logging.info(f"验证API密钥: {api_key}")
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
logging.warning(f"API密钥不存在: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
logging.warning(f"API密钥已停用: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
logging.warning(f"API密钥已过期: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
logging.info(f"API密钥验证成功: {api_key}")
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 0.1)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 100000000) * 0.1)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 100000000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
|
||||
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
|
||||
if task_type not in KAFKA_TOPICS:
|
||||
raise HTTPException(status_code=400, detail="不支持的任务类型")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
await update_token_usage(api_key, tokens_required, task_type)
|
||||
|
||||
# 创建任务记录
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"filename": new_filename,
|
||||
"file_type": file_type,
|
||||
"task_type": task_type,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 存储任务信息到Redis
|
||||
redis_client.set(f"task:{task_id}", json.dumps(task_data))
|
||||
logging.info(f"任务信息已存储到Redis: {task_id}")
|
||||
|
||||
# 发送任务到对应的Kafka主题
|
||||
kafka_topic = KAFKA_TOPICS[task_type]
|
||||
producer.send(kafka_topic, task_data)
|
||||
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
|
||||
|
||||
response_data = {
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"task_id": task_id,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{task_type}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
}
|
||||
logging.info(f"上传文件完成: {task_id}")
|
||||
return JSONResponse(content=response_data)
|
||||
|
||||
# 为每个任务类型创建单独的端点
|
||||
@v1_app.post("/pose")
|
||||
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
logging.info(f"收到 /pose端点的请求")
|
||||
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/cpm")
|
||||
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/qwenvl")
|
||||
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/yolo")
|
||||
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/fall")
|
||||
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/face")
|
||||
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
logging.info(f"收到 /face 端点的请求")
|
||||
return await upload_file(file, task_type="face", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/mediapipe")
|
||||
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/compare")
|
||||
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="compare", api_key_info=api_key_info)
|
||||
|
||||
|
||||
@v1_app.get("/result/{task_id}")
|
||||
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
# 从 REDIS_DB (15) 获取任务状态
|
||||
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||
|
||||
if task_info['status'] != 'completed':
|
||||
return {"status": task_info['status'], "message": "Task is not completed yet"}
|
||||
|
||||
result_type = task_info['result_type']
|
||||
result_key = task_info['result_key']
|
||||
|
||||
# 根据任务类型选择相应的 Redis 客户端
|
||||
redis_client_map = {
|
||||
'pose': redis_pose_client,
|
||||
'cpm': redis_cpm_client,
|
||||
'yolo': redis_yolo_client,
|
||||
'face': redis_face_client,
|
||||
'fall': redis_fall_client,
|
||||
'mediapipe': redis_mediapipe_client,
|
||||
'qwenvl': redis_qwenvl_client,
|
||||
'compare': redis_compare_client
|
||||
}
|
||||
|
||||
result_redis = redis_client_map.get(result_type)
|
||||
if not result_redis:
|
||||
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||
|
||||
result = result_redis.hgetall(result_key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||
|
||||
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||
|
||||
# 将 result 字段解析为 JSON(如果存在)
|
||||
if 'result' in result:
|
||||
result['result'] = json.loads(result['result'])
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"result_type": result_type,
|
||||
"result": result
|
||||
}
|
||||
|
||||
@v1_app.get("/annotated/{task_id}")
|
||||
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
# 从 REDIS_DB (15) 获取任务信息
|
||||
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||
|
||||
if task_info['status'] != 'completed':
|
||||
raise HTTPException(status_code=400, detail="Task is not completed yet")
|
||||
|
||||
result_type = task_info.get('result_type')
|
||||
result_key = task_info.get('result_key')
|
||||
|
||||
if not result_key:
|
||||
raise HTTPException(status_code=404, detail="Result key not found")
|
||||
|
||||
if result_type in ['cpm', 'qwenvl']:
|
||||
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
|
||||
|
||||
# 根据任务类型选择相应的 Redis 客户端
|
||||
redis_client_map = {
|
||||
'pose': redis_pose_client,
|
||||
'yolo': redis_yolo_client,
|
||||
'face': redis_face_client,
|
||||
'fall': redis_fall_client,
|
||||
'mediapipe': redis_mediapipe_client,
|
||||
'compare': redis_compare_client
|
||||
}
|
||||
|
||||
result_redis = redis_client_map.get(result_type)
|
||||
if not result_redis:
|
||||
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||
|
||||
result = result_redis.hgetall(result_key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||
|
||||
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||
|
||||
result_file = result.get('result_file')
|
||||
if not result_file:
|
||||
raise HTTPException(status_code=404, detail="Result file not found")
|
||||
|
||||
file_path = os.path.join(RESULT_DIR, result_file)
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail="Result image file not found")
|
||||
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8005)
|
||||
@@ -1,255 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from decord import VideoReader, cpu
|
||||
import json
|
||||
import re
|
||||
from pymongo import MongoClient
|
||||
import io
|
||||
from minio import Minio
|
||||
import time
|
||||
from bson import ObjectId
|
||||
import concurrent.futures
|
||||
import os
|
||||
|
||||
|
||||
# Minio连接模块
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
data = response.read()
|
||||
# print(f"Read {len(data)} bytes from Minio for {object_name}")
|
||||
return data
|
||||
|
||||
# 数据库连接模块
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.minio_files_collection = self.db['minio_files']
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
processed_etags = set(self.results_collection.distinct('etag'))
|
||||
return self.minio_files_collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'},
|
||||
'etag': {'$nin': list(processed_etags)}
|
||||
})
|
||||
def save_result(self, result):
|
||||
# 检查是否已存在相同 etag 的结果
|
||||
existing_result = self.results_collection.find_one({'etag': result['etag']})
|
||||
if existing_result:
|
||||
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
|
||||
return
|
||||
|
||||
# 将 ObjectId 转换为字符串
|
||||
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
|
||||
result['video_id'] = str(result['video_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
|
||||
# 视频处理模块
|
||||
class VideoProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.MAX_NUM_FRAMES = 64
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
# 环境提取
|
||||
environments = ["办公室", "室内", "室外", "会议室", "办公"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
# 改进的人数提取
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
# 动作和互动提取
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
# 物体和家具提取
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 主处理类
|
||||
class VideoAnalysisSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.video_processor = VideoProcessor(model_dir)
|
||||
|
||||
def process_video(self, video_doc):
|
||||
start_time = time.time()
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
result = self.video_processor.process_video(video_data, video_doc['object_name'])
|
||||
|
||||
result['etag'] = video_doc['etag']
|
||||
result['bucket_name'] = video_doc['bucket_name']
|
||||
result['object_name'] = video_doc['object_name']
|
||||
|
||||
self.db_handler.save_result(result)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(60)
|
||||
continue
|
||||
|
||||
for video_doc in unprocessed_videos:
|
||||
self.process_video(video_doc)
|
||||
|
||||
print("Finished processing current batch of videos. Waiting for new videos...")
|
||||
time.sleep(30)
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "douyin_results"
|
||||
|
||||
model_dir = "MiniCPM-V-2_6"
|
||||
|
||||
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name)
|
||||
system.run()
|
||||
@@ -1,248 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from decord import VideoReader, cpu
|
||||
import json
|
||||
import re
|
||||
from pymongo import MongoClient
|
||||
import io
|
||||
from minio import Minio
|
||||
import time
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
# Minio连接模块
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
data = response.read()
|
||||
print(f"Read {len(data)} bytes from Minio for {object_name}")
|
||||
return data
|
||||
|
||||
# 数据库连接模块
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.minio_files_collection = self.db['minio_files']
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
# 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频
|
||||
processed_etags = set(self.results_collection.distinct('etag'))
|
||||
return self.minio_files_collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': 'videoupload/'},
|
||||
'etag': {'$nin': list(processed_etags)}
|
||||
})
|
||||
|
||||
def save_result(self, result):
|
||||
# 检查是否已存在相同 etag 的结果
|
||||
existing_result = self.results_collection.find_one({'etag': result['etag']})
|
||||
if existing_result:
|
||||
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
|
||||
return
|
||||
|
||||
# 将 ObjectId 转换为字符串
|
||||
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
|
||||
result['video_id'] = str(result['video_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
|
||||
# 视频处理模块
|
||||
class VideoProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.MAX_NUM_FRAMES = 64
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
# 环境提取
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
# 改进的人数提取
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
# 动作和互动提取
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
# 物体和家具提取
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 主处理类
|
||||
class VideoAnalysisSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.video_processor = VideoProcessor(model_dir)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
for video_doc in unprocessed_videos:
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
result = self.video_processor.process_video(video_data, video_doc['object_name'])
|
||||
|
||||
# 添加额外信息到结果中
|
||||
result['etag'] = video_doc['etag']
|
||||
# result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串
|
||||
result['bucket_name'] = video_doc['bucket_name']
|
||||
result['object_name'] = video_doc['object_name']
|
||||
|
||||
# 保存结果到 MongoDB
|
||||
self.db_handler.save_result(result)
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
# print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder))
|
||||
except Exception as e:
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc() # 打印完整的错误堆栈
|
||||
|
||||
print("Finished processing current batch of videos. Waiting for new videos...")
|
||||
time.sleep(30)
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "douyin_results"
|
||||
|
||||
model_dir = "OpenBMB/MiniCPM-V-2_6"
|
||||
|
||||
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name)
|
||||
system.run()
|
||||
@@ -1,264 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from decord import VideoReader, cpu
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from pymongo import MongoClient
|
||||
import io
|
||||
from minio import Minio
|
||||
import time
|
||||
import os
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
# Minio连接模块
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
return response.read()
|
||||
|
||||
# 数据库连接模块
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.minio_files_collection = self.db['minio_files']
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
# 查找 bucket_name 为 'raw-video' 且在结果集合中没有对应 etag 的视频
|
||||
processed_etags = set(self.results_collection.distinct('etag'))
|
||||
return self.minio_files_collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': '/douyin/'},
|
||||
'etag': {'$nin': list(processed_etags)}
|
||||
})
|
||||
|
||||
def save_result(self, result):
|
||||
# 检查是否已存在相同 etag 的结果
|
||||
existing_result = self.results_collection.find_one({'etag': result['etag']})
|
||||
if existing_result:
|
||||
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
|
||||
return
|
||||
|
||||
# 将 ObjectId 转换为字符串
|
||||
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
|
||||
result['video_id'] = str(result['video_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
|
||||
# 视频处理模块
|
||||
class VideoProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.MAX_NUM_FRAMES = 64
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
start_time, end_time = self.extract_time_from_filename(object_name)
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
"start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
# 从 object_name 中提取文件名
|
||||
filename = os.path.basename(object_name)
|
||||
|
||||
# 从文件名中提取日期时间部分
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
# 环境提取
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
# 改进的人数提取
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
# 动作和互动提取
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
# 物体和家具提取
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 主处理类
|
||||
class VideoAnalysisSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.video_processor = VideoProcessor(model_dir)
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
for video_doc in unprocessed_videos:
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
result = self.video_processor.process_video(video_data, video_doc['object_name'])
|
||||
|
||||
# 添加额外信息到结果中
|
||||
result['etag'] = video_doc['etag']
|
||||
# result['video_id'] = str(video_doc['_id']) # 将 ObjectId 转换为字符串
|
||||
result['bucket_name'] = video_doc['bucket_name']
|
||||
result['object_name'] = video_doc['object_name']
|
||||
|
||||
# 保存结果到 MongoDB
|
||||
self.db_handler.save_result(result)
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
# print(json.dumps(result, ensure_ascii=False, indent=2, cls=JSONEncoder))
|
||||
except Exception as e:
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc() # 打印完整的错误堆栈
|
||||
|
||||
print("Finished processing current batch of videos. Waiting for new videos...")
|
||||
time.sleep(30)
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "douyin_results"
|
||||
|
||||
model_dir = "OpenBMB/MiniCPM-V-2_6"
|
||||
|
||||
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name)
|
||||
system.run()
|
||||
@@ -1,121 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from bson import ObjectId
|
||||
from minio import Minio
|
||||
from pymongo import MongoClient
|
||||
import whisper
|
||||
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
data = response.read()
|
||||
print(f"Read {len(data)} bytes from Minio for {object_name}")
|
||||
return data
|
||||
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.collection = self.db[collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
return self.collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': r'^douyin/.*/.+\.(mp4|avi|mov|flv)$'},
|
||||
'whisper_transcription': {'$exists': False}
|
||||
})
|
||||
|
||||
def update_transcription(self, video_id, transcription):
|
||||
self.collection.update_one(
|
||||
{'_id': video_id},
|
||||
{'$set': {'whisper_transcription': transcription}}
|
||||
)
|
||||
|
||||
class WhisperProcessor:
|
||||
def __init__(self, model_name, model_path=None):
|
||||
if model_path:
|
||||
self.model = whisper.load_model(model_name, download_root=model_path)
|
||||
else:
|
||||
self.model = whisper.load_model(model_name)
|
||||
|
||||
def transcribe_audio(self, video_data):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
|
||||
temp_video.write(video_data)
|
||||
temp_video_path = temp_video.name
|
||||
|
||||
try:
|
||||
result = self.model.transcribe(temp_video_path)
|
||||
return result["text"]
|
||||
finally:
|
||||
os.unlink(temp_video_path)
|
||||
|
||||
class WhisperTranscriptionSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name, whisper_model_name, whisper_model_path=None):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
|
||||
self.whisper_processor = WhisperProcessor(whisper_model_name, whisper_model_path)
|
||||
|
||||
def process_video(self, video_doc):
|
||||
start_time = time.time()
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
transcription = self.whisper_processor.transcribe_audio(video_data)
|
||||
|
||||
self.db_handler.update_transcription(video_doc['_id'], transcription)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(60)
|
||||
continue
|
||||
|
||||
for video_doc in unprocessed_videos:
|
||||
self.process_video(video_doc)
|
||||
|
||||
print("Finished processing current batch of videos. Waiting for new videos...")
|
||||
time.sleep(30)
|
||||
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
collection_name = "douyin_results"
|
||||
|
||||
whisper_model_name = "large-v3" # 指定模型名称
|
||||
whisper_model_path = "whisper" # 指定模型存放路径
|
||||
|
||||
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name,
|
||||
whisper_model_name, whisper_model_path)
|
||||
system.run()
|
||||
@@ -1,114 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from bson import ObjectId
|
||||
from minio import Minio
|
||||
from pymongo import MongoClient
|
||||
import whisper
|
||||
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
data = response.read()
|
||||
print(f"Read {len(data)} bytes from Minio for {object_name}")
|
||||
return data
|
||||
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.collection = self.db[collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
return self.collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': 'douyin/'},
|
||||
'whisper_transcription': {'$exists': False}
|
||||
})
|
||||
|
||||
def update_transcription(self, video_id, transcription):
|
||||
self.collection.update_one(
|
||||
{'_id': video_id},
|
||||
{'$set': {'whisper_transcription': transcription}}
|
||||
)
|
||||
|
||||
class WhisperProcessor:
|
||||
def __init__(self, model_name="large-v3"):
|
||||
self.model = whisper.load_model(model_name)
|
||||
|
||||
def transcribe_audio(self, video_data):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
|
||||
temp_video.write(video_data)
|
||||
temp_video_path = temp_video.name
|
||||
|
||||
try:
|
||||
result = self.model.transcribe(temp_video_path)
|
||||
return result["text"]
|
||||
finally:
|
||||
os.unlink(temp_video_path)
|
||||
|
||||
class WhisperTranscriptionSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
|
||||
self.whisper_processor = WhisperProcessor()
|
||||
|
||||
def process_video(self, video_doc):
|
||||
start_time = time.time()
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
transcription = self.whisper_processor.transcribe_audio(video_data)
|
||||
|
||||
self.db_handler.update_transcription(video_doc['_id'], transcription)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(60)
|
||||
continue
|
||||
|
||||
print(f"Found {len(unprocessed_videos)} videos to process.")
|
||||
for video_doc in unprocessed_videos:
|
||||
self.process_video(video_doc)
|
||||
|
||||
print("Finished processing current batch of videos. Checking for more...")
|
||||
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
collection_name = "douyin_results"
|
||||
|
||||
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name)
|
||||
system.run()
|
||||
@@ -1,114 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from bson import ObjectId
|
||||
from minio import Minio
|
||||
from pymongo import MongoClient
|
||||
import whisper
|
||||
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=True
|
||||
)
|
||||
|
||||
def get_video_data(self, bucket, object_name):
|
||||
response = self.client.get_object(bucket, object_name)
|
||||
data = response.read()
|
||||
print(f"Read {len(data)} bytes from Minio for {object_name}")
|
||||
return data
|
||||
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.collection = self.db[collection_name]
|
||||
|
||||
def get_unprocessed_videos(self):
|
||||
return self.collection.find({
|
||||
'bucket_name': 'raw',
|
||||
'object_name': {'$regex': 'douyin/'},
|
||||
'whisper_transcription': {'$exists': False}
|
||||
})
|
||||
|
||||
def update_transcription(self, video_id, transcription):
|
||||
self.collection.update_one(
|
||||
{'_id': video_id},
|
||||
{'$set': {'whisper_transcription': transcription}}
|
||||
)
|
||||
|
||||
class WhisperProcessor:
|
||||
def __init__(self, model_name="large-v3"):
|
||||
self.model = whisper.load_model(model_name)
|
||||
|
||||
def transcribe_audio(self, video_data):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
|
||||
temp_video.write(video_data)
|
||||
temp_video_path = temp_video.name
|
||||
|
||||
try:
|
||||
result = self.model.transcribe(temp_video_path)
|
||||
return result["text"]
|
||||
finally:
|
||||
os.unlink(temp_video_path)
|
||||
|
||||
class WhisperTranscriptionSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, collection_name)
|
||||
self.whisper_processor = WhisperProcessor()
|
||||
|
||||
def process_video(self, video_doc):
|
||||
start_time = time.time()
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
transcription = self.whisper_processor.transcribe_audio(video_data)
|
||||
|
||||
self.db_handler.update_transcription(video_doc['_id'], transcription)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = list(self.db_handler.get_unprocessed_videos())
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 60 seconds before checking again...")
|
||||
time.sleep(60)
|
||||
continue
|
||||
|
||||
print(f"Found {len(unprocessed_videos)} videos to process.")
|
||||
for video_doc in unprocessed_videos:
|
||||
self.process_video(video_doc)
|
||||
|
||||
print("Finished processing current batch of videos. Checking for more...")
|
||||
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
collection_name = "douyin_results"
|
||||
|
||||
system = WhisperTranscriptionSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, collection_name)
|
||||
system.run()
|
||||
@@ -1,256 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from decord import VideoReader, cpu
|
||||
import json
|
||||
import re
|
||||
from pymongo import MongoClient
|
||||
import io
|
||||
from minio import Minio
|
||||
import time
|
||||
from bson import ObjectId
|
||||
import concurrent.futures
|
||||
import os
|
||||
|
||||
class MinioHandler:
|
||||
def __init__(self, endpoint, access_key, secret_key, secure=True):
|
||||
self.client = Minio(
|
||||
endpoint,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=secure
|
||||
)
|
||||
|
||||
def list_objects(self, bucket_name, prefix):
|
||||
objects = self.client.list_objects(bucket_name, prefix=prefix, recursive=True)
|
||||
return [obj for obj in objects if obj.object_name.lower().endswith(('.mp4', '.avi', '.mov', '.flv'))]
|
||||
|
||||
def get_video_data(self, bucket_name, object_name):
|
||||
try:
|
||||
response = self.client.get_object(bucket_name, object_name)
|
||||
return response.read()
|
||||
except Exception as e:
|
||||
print(f"Error retrieving video data for {object_name}: {str(e)}")
|
||||
return None
|
||||
|
||||
class DatabaseHandler:
|
||||
def __init__(self, mongo_uri, database_name, results_collection_name):
|
||||
self.client = MongoClient(mongo_uri)
|
||||
self.db = self.client[database_name]
|
||||
self.results_collection = self.db[results_collection_name]
|
||||
|
||||
def get_unprocessed_videos(self, minio_handler, bucket_name='raw', prefix='videoupload/'):
|
||||
all_objects = minio_handler.list_objects(bucket_name, prefix)
|
||||
processed_etags = set(self.results_collection.distinct('etag'))
|
||||
|
||||
unprocessed_videos = [
|
||||
{
|
||||
'bucket_name': bucket_name,
|
||||
'object_name': obj.object_name,
|
||||
'etag': obj.etag,
|
||||
'size': obj.size,
|
||||
'last_modified': obj.last_modified
|
||||
}
|
||||
for obj in all_objects if obj.etag not in processed_etags
|
||||
]
|
||||
|
||||
return unprocessed_videos
|
||||
|
||||
def save_result(self, result):
|
||||
existing_result = self.results_collection.find_one({'etag': result['etag']})
|
||||
if existing_result:
|
||||
print(f"Video with etag {result['etag']} has already been processed. Skipping.")
|
||||
return
|
||||
|
||||
if 'video_id' in result and isinstance(result['video_id'], ObjectId):
|
||||
result['video_id'] = str(result['video_id'])
|
||||
|
||||
self.results_collection.insert_one(result)
|
||||
|
||||
class JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if isinstance(o, ObjectId):
|
||||
return str(o)
|
||||
return super().default(o)
|
||||
|
||||
class VideoProcessor:
|
||||
def __init__(self, model_dir):
|
||||
self.model = AutoModel.from_pretrained(model_dir, trust_remote_code=True,
|
||||
attn_implementation='sdpa', torch_dtype=torch.bfloat16).eval().cuda()
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.MAX_NUM_FRAMES = 12
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室", "办公"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
class VideoAnalysisSystem:
|
||||
def __init__(self, minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name):
|
||||
self.minio_handler = MinioHandler(minio_endpoint, minio_access_key, minio_secret_key)
|
||||
self.db_handler = DatabaseHandler(mongo_uri, db_name, results_collection_name)
|
||||
self.video_processor = VideoProcessor(model_dir)
|
||||
|
||||
def process_video(self, video_doc):
|
||||
start_time = time.time()
|
||||
try:
|
||||
video_data = self.minio_handler.get_video_data(video_doc['bucket_name'], video_doc['object_name'])
|
||||
result = self.video_processor.process_video(video_data, video_doc['object_name'])
|
||||
|
||||
result['etag'] = video_doc['etag']
|
||||
result['bucket_name'] = video_doc['bucket_name']
|
||||
result['object_name'] = video_doc['object_name']
|
||||
|
||||
self.db_handler.save_result(result)
|
||||
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Processed video: {video_doc['object_name']}")
|
||||
print(f"Processing time: {processing_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
processing_time = end_time - start_time
|
||||
|
||||
print(f"Error processing video {video_doc['object_name']}: {str(e)}")
|
||||
print(f"Processing time (including error): {processing_time:.2f} seconds")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
unprocessed_videos = self.db_handler.get_unprocessed_videos(self.minio_handler)
|
||||
|
||||
if not unprocessed_videos:
|
||||
print("No new videos to process. Waiting for 5 seconds before checking again...")
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
for video_doc in unprocessed_videos:
|
||||
self.process_video(video_doc)
|
||||
|
||||
print("Finished processing current batch of videos. Waiting for new videos...")
|
||||
time.sleep(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
minio_endpoint = "api.obscura.work"
|
||||
minio_access_key = "MnHTAG2NOLyXXIZrwDLp"
|
||||
minio_secret_key = "WVlmMgww0aRIU43pCJ1XCjubXQO6YsbHysxX2hBf"
|
||||
|
||||
mongo_uri = "mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/minio_mongo"
|
||||
db_name = "minio_mongo"
|
||||
results_collection_name = "videoupload_results"
|
||||
|
||||
model_dir = "MiniCPM-V-2_6"
|
||||
|
||||
system = VideoAnalysisSystem(minio_endpoint, minio_access_key, minio_secret_key,
|
||||
mongo_uri, db_name, model_dir, results_collection_name)
|
||||
system.run()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,369 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
cpm_app = FastAPI()
|
||||
app.mount("/cpm", cpm_app)
|
||||
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "cpm"
|
||||
KAFKA_GROUP_ID = "cpm_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 5
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
# 设置 GPU 设备
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# 初始化模型
|
||||
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
model = model.half().cuda().eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = torch.device("cuda:0")
|
||||
self.model = self.model.to(self.device)
|
||||
self.MAX_NUM_FRAMES = 16
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
answer = self.model.chat(
|
||||
image=frames, # 直接传递 frames
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
# "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
# "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
question = "描述这张图片,包括场景、人物数量和行为等细节。"
|
||||
msgs = [
|
||||
{'role': 'user', 'content': [image] + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, tokenizer)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str):
|
||||
content = await file.read()
|
||||
# 获取原始文件的后缀
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
# 生成新的文件名,包含 UUID 和原始后缀
|
||||
filename = f"cpm_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename}
|
||||
|
||||
@cpm_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@cpm_app.post("/analyze_video")
|
||||
async def analyze_video(file: UploadFile = File(...)):
|
||||
try:
|
||||
return await process_file(file, "video")
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@cpm_app.post("/analyze_image")
|
||||
async def analyze_image(file: UploadFile = File(...)):
|
||||
try:
|
||||
return await process_file(file, "image")
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@cpm_app.get("/result/{filename}")
|
||||
async def get_result(filename: str):
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -1,518 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
cpm_app = FastAPI()
|
||||
app.mount("/cpm", cpm_app)
|
||||
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "cpm"
|
||||
KAFKA_GROUP_ID = "cpm_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 5
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
# 设置 GPU 设备
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# 初始化模型
|
||||
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
model = model.half().cuda().eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = torch.device("cuda:0")
|
||||
self.model = self.model.to(self.device)
|
||||
self.MAX_NUM_FRAMES = 16
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
answer = self.model.chat(
|
||||
image=frames, # 直接传递 frames
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
# "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
# "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
question = "描述这张图片,包括场景、人物数量和行为等细节。"
|
||||
msgs = [
|
||||
{'role': 'user', 'content': [image] + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, tokenizer)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str, api_key: str):
|
||||
content = await file.read()
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
filename = f"cpm_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算token
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新token使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新token使用量
|
||||
model_name = "MiniCPM-V-2_6"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
@cpm_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type, api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@cpm_app.post("/analyze_video")
|
||||
async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "video", api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@cpm_app.post("/analyze_image")
|
||||
async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "image", api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@cpm_app.get("/result/{filename}")
|
||||
async def get_result(filename: str):
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -1,526 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
import string
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
cpm_app = FastAPI()
|
||||
app.mount("/cpm", cpm_app)
|
||||
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "cpm"
|
||||
KAFKA_GROUP_ID = "cpm_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 5
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
# 设置 GPU 设备
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# 初始化模型
|
||||
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
model = model.half().cuda().eval()
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, tokenizer):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = torch.device("cuda:0")
|
||||
self.model = self.model.to(self.device)
|
||||
self.MAX_NUM_FRAMES = 16
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
if not video_data:
|
||||
raise ValueError(f"Empty video data for {object_name}")
|
||||
print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
|
||||
frames = self.encode_video(video_data)
|
||||
question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
|
||||
msgs = [
|
||||
{'role': 'user', 'content': frames + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
answer = self.model.chat(
|
||||
image=frames, # 直接传递 frames
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"num_frames": len(frames),
|
||||
# "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
# "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
question = "描述这张图片,包括场景、人物数量和行为等细节。"
|
||||
msgs = [
|
||||
{'role': 'user', 'content': [image] + [question]},
|
||||
]
|
||||
|
||||
params = {
|
||||
"use_image_id": False,
|
||||
"max_slice_nums": 1
|
||||
}
|
||||
|
||||
answer = self.model.chat(
|
||||
image=None,
|
||||
msgs=msgs,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=512,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
**params
|
||||
)
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
return {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
for action in actions:
|
||||
if action in answer:
|
||||
info["actions"].append(action)
|
||||
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
for interaction in interactions:
|
||||
if interaction in answer:
|
||||
info["interactions"].append(interaction)
|
||||
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for obj in objects:
|
||||
if obj in answer:
|
||||
info["objects"].append(obj)
|
||||
|
||||
for item in furniture:
|
||||
if item in answer:
|
||||
info["furniture"].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, tokenizer)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str, api_key_info: dict):
|
||||
content = await file.read()
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
filename = f"cpm_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算token
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新token使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新token使用量
|
||||
model_name = "MiniCPM-V-2_6"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
@cpm_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type, api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@cpm_app.post("/analyze_video")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
api_key_info = await verify_api_key(api_key_info)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "video", api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@cpm_app.post("/analyze_image")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
api_key_info = await verify_api_key(api_key_info)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "image", api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@cpm_app.get("/result/{filename}")
|
||||
async def get_result(filename: str):
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -1,312 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
import torch
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
face_app = FastAPI()
|
||||
app.mount("/face", face_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-face.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "face" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "face_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 7
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
class faceDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = faceDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"face_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"face_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"face_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@face_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"face_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
||||
|
||||
@face_app.get("/result/{filename}")
|
||||
async def get_face_result(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@face_app.get("/annotated/{filename}")
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"face_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"face_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7004)
|
||||
@@ -1,462 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import torch
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
face_app = FastAPI()
|
||||
app.mount("/face", face_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-face.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "face" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "face_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 7
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class faceDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = faceDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"face_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"face_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"face_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@face_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
|
||||
# 验证 API key
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8n-face"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"face_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
@face_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_face_result(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@face_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"face_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"face_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7004)
|
||||
@@ -1,471 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import torch
|
||||
|
||||
import string
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
face_app = FastAPI()
|
||||
app.mount("/face", face_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-face.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "face" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "face_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 7
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class faceDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = faceDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"face_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"face_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"face_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@face_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8n-face"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"face_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
@face_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_face_result(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@face_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"face_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"face_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:face_result:*') # 监听所有face_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"face_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7004)
|
||||
@@ -1,293 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
fall_app = FastAPI()
|
||||
app.mount("/fall", fall_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-fall.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "fall" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "fall_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 4
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
class fallDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
if not hasattr(r, 'boxes') or len(r.boxes) == 0:
|
||||
print("没有检测到任何对象")
|
||||
return [{"message": "No objects detected"}]
|
||||
|
||||
boxes = r.boxes
|
||||
names = getattr(r, 'names', {})
|
||||
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'):
|
||||
print(f"警告: 第 {i} 个框缺少必要的属性")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(box.cls.item())
|
||||
formatted_result = {
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"class_id": class_id,
|
||||
"class": names.get(class_id, f"Unknown-{class_id}")
|
||||
}
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
print(f"处理第 {i} 个框时出错: {str(e)}")
|
||||
|
||||
# print("格式化后的结果:", formatted_results)
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results):
|
||||
for r in results:
|
||||
annotated_frame = r.plot()
|
||||
return annotated_frame
|
||||
|
||||
detector = fallDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on image
|
||||
annotated_img = detector.draw_results(img, results)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"fall_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"fall_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
results = detector.detect(frame)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@fall_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"fall_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
||||
|
||||
@fall_app.get("/result/{filename}")
|
||||
async def get_fall_result(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@fall_app.get("/annotated/{filename}")
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"fall_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"fall_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7002)
|
||||
@@ -1,442 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
fall_app = FastAPI()
|
||||
app.mount("/fall", fall_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-fall.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "fall" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "fall_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 4
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class fallDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
if not hasattr(r, 'boxes') or len(r.boxes) == 0:
|
||||
print("没有检测到任何对象")
|
||||
return [{"message": "No objects detected"}]
|
||||
|
||||
boxes = r.boxes
|
||||
names = getattr(r, 'names', {})
|
||||
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'):
|
||||
print(f"警告: 第 {i} 个框缺少必要的属性")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(box.cls.item())
|
||||
formatted_result = {
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"class_id": class_id,
|
||||
"class": names.get(class_id, f"Unknown-{class_id}")
|
||||
}
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
print(f"处理第 {i} 个框时出错: {str(e)}")
|
||||
|
||||
# print("格式化后的结果:", formatted_results)
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results):
|
||||
for r in results:
|
||||
annotated_frame = r.plot()
|
||||
return annotated_frame
|
||||
|
||||
detector = fallDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on image
|
||||
annotated_img = detector.draw_results(img, results)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"fall_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"fall_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
results = detector.detect(frame)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@fall_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
|
||||
# 验证 API key
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
# 检查并更新 token 使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8n-fall"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"fall_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@fall_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_fall_result(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@fall_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"fall_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"fall_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7002)
|
||||
@@ -1,449 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import string
|
||||
app = FastAPI()
|
||||
fall_app = FastAPI()
|
||||
app.mount("/fall", fall_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8n-fall.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "fall" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "fall_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 4
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class fallDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
if not hasattr(r, 'boxes') or len(r.boxes) == 0:
|
||||
print("没有检测到任何对象")
|
||||
return [{"message": "No objects detected"}]
|
||||
|
||||
boxes = r.boxes
|
||||
names = getattr(r, 'names', {})
|
||||
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
if not hasattr(box, 'cls') or not hasattr(box, 'conf') or not hasattr(box, 'xyxy'):
|
||||
print(f"警告: 第 {i} 个框缺少必要的属性")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(box.cls.item())
|
||||
formatted_result = {
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"class_id": class_id,
|
||||
"class": names.get(class_id, f"Unknown-{class_id}")
|
||||
}
|
||||
formatted_results.append(formatted_result)
|
||||
except Exception as e:
|
||||
print(f"处理第 {i} 个框时出错: {str(e)}")
|
||||
|
||||
# print("格式化后的结果:", formatted_results)
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results):
|
||||
for r in results:
|
||||
annotated_frame = r.plot()
|
||||
return annotated_frame
|
||||
|
||||
detector = fallDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on image
|
||||
annotated_img = detector.draw_results(img, results)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"fall_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"fall_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"fall_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
results = detector.detect(frame)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@fall_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8n-fall"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"fall_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@fall_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_fall_result(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@fall_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"fall_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"fall_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"fall_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7002)
|
||||
@@ -1,297 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from redis import Redis
|
||||
import json
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
import redis
|
||||
import torch
|
||||
import media as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
|
||||
# Configuration
|
||||
MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "mediapipe"
|
||||
KAFKA_GROUP_ID = "mediapipe_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 9 # POSE Worker使用的Redis DB
|
||||
MAIN_REDIS_DB = 15 # 主Redis DB
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# Initialize Kafka
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
main_redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=MAIN_REDIS_DB
|
||||
)
|
||||
|
||||
|
||||
class mediapipeEmbedder:
|
||||
def __init__(self, model_path):
|
||||
base_options = python.BaseOptions(model_asset_path=model_path)
|
||||
options = vision.FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True,
|
||||
output_facial_transformation_matrixes=True,
|
||||
num_faces=1
|
||||
)
|
||||
self.detector = vision.FaceLandmarker.create_from_options(options)
|
||||
|
||||
def get_mediapipe_landmarks(self, image):
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
||||
detection_result = self.detector.detect(mp_image)
|
||||
if detection_result.face_landmarks:
|
||||
return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
|
||||
return None
|
||||
|
||||
def process_image(self, image_data):
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
landmarks = self.get_mediapipe_landmarks(img)
|
||||
|
||||
if landmarks is not None:
|
||||
# Calculate a more detailed mediapipe embedding
|
||||
embedding = self.calculate_detailed_embedding(landmarks)
|
||||
|
||||
# Draw landmarks on the image
|
||||
for lm in landmarks:
|
||||
cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)
|
||||
|
||||
return {
|
||||
"embedding": embedding,
|
||||
"landmarks": landmarks.tolist()
|
||||
}, img
|
||||
else:
|
||||
return None, img
|
||||
|
||||
def calculate_detailed_embedding(self, landmarks):
|
||||
# Calculate various statistical features
|
||||
mean = np.mean(landmarks, axis=0)
|
||||
std = np.std(landmarks, axis=0)
|
||||
median = np.median(landmarks, axis=0)
|
||||
min_vals = np.min(landmarks, axis=0)
|
||||
max_vals = np.max(landmarks, axis=0)
|
||||
|
||||
# Calculate pairwise distances between key facial landmarks
|
||||
nose_tip = landmarks[4]
|
||||
left_eye = landmarks[159]
|
||||
right_eye = landmarks[386]
|
||||
left_mouth = landmarks[61]
|
||||
right_mouth = landmarks[291]
|
||||
|
||||
eye_distance = np.linalg.norm(left_eye - right_eye)
|
||||
mouth_width = np.linalg.norm(left_mouth - right_mouth)
|
||||
nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)
|
||||
|
||||
# Calculate face shape features
|
||||
face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
|
||||
face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
|
||||
face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])
|
||||
|
||||
# Combine all features into a single embedding
|
||||
embedding = np.concatenate([
|
||||
mean, std, median, min_vals, max_vals,
|
||||
[eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
|
||||
])
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
embedder = mediapipeEmbedder(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
results, annotated_img = embedder.process_image(image_data)
|
||||
|
||||
if results:
|
||||
# Save annotated image
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return results, annotated_filename
|
||||
else:
|
||||
print(f"No face landmarks detected in image: {filename}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"Error processing image {filename}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if frame_count % fps == 0:
|
||||
frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
|
||||
if frame_results:
|
||||
results.append({"frame": frame_count, "results": frame_results})
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_task():
|
||||
print("开始处理任务,等待Kafka消息...")
|
||||
for message in consumer:
|
||||
print(f"收到Kafka消息: topic={message.topic}, partition={message.partition}, offset={message.offset}")
|
||||
task = message.value
|
||||
task_id = task['task_id']
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
print(f"解析任务信息: ID={task_id}, 文件名={filename}, 类型={file_type}")
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
# 检查键类型并更新状态
|
||||
task_key = f"task:{task_id}"
|
||||
try:
|
||||
key_type = main_redis_client.type(task_key)
|
||||
if key_type != b'hash':
|
||||
main_redis_client.delete(task_key)
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "processing")
|
||||
print(f"任务 {task_id} 状态更新为 'processing'")
|
||||
except redis.exceptions.ResponseError as e:
|
||||
print(f"更新任务 {task_id} 状态时出错: {str(e)}")
|
||||
continue # 跳过这个任务,继续处理下一个
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
print(f"开始处理图像: {filename}")
|
||||
with open(file_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
json_results, annotated_filename = process_image(image_data, filename)
|
||||
if json_results and annotated_filename is not None:
|
||||
result_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
|
||||
redis_client.hset(f"fall_result:{task_id}", mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": annotated_filename
|
||||
})
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "completed")
|
||||
main_redis_client.hset(f"task:{task_id}", "result_type", "fall")
|
||||
main_redis_client.hset(f"task:{task_id}", "result_key", f"fall_result:{task_id}")
|
||||
print(f"图像 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"图像 {filename} 处理失败")
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "failed")
|
||||
else: # video
|
||||
print(f"开始处理视频: {filename}")
|
||||
with open(file_path, 'rb') as f:
|
||||
video_data = f.read()
|
||||
json_results, annotated_filename = process_video(video_data, filename)
|
||||
if json_results and annotated_filename:
|
||||
redis_client.hset(f"fall_result:{task_id}", mapping={
|
||||
"result": json.dumps(json_results),
|
||||
"result_file": annotated_filename
|
||||
})
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "completed")
|
||||
main_redis_client.hset(f"task:{task_id}", "result_type", "fall")
|
||||
main_redis_client.hset(f"task:{task_id}", "result_key", f"fall_result:{task_id}")
|
||||
|
||||
|
||||
|
||||
print(f"视频 {filename} 处理完成,结果已保存")
|
||||
else:
|
||||
print(f"视频 {filename} 处理失败")
|
||||
main_redis_client.hset(f"task:{task_id}", "status", "failed")
|
||||
except Exception as e:
|
||||
print(f"处理任务 {task_id} 时出错: {str(e)}")
|
||||
main_redis_client.hset(f"task:{task_id}", {
|
||||
"status": "failed",
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
print(f"任务 {task_id} 处理完毕,等待下一个Kafka消息...")
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:fall_result:*') # 监听所有fall_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'hset':
|
||||
value = redis_client.hgetall(f"fall_result:{key}")
|
||||
if value:
|
||||
result = {k.decode(): v.decode() for k, v in value.items()}
|
||||
print(f"Result update for task {key}: {result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("fall处理程序启动...")
|
||||
# 启动处理任务的线程
|
||||
task_thread = threading.Thread(target=process_task, daemon=True)
|
||||
task_thread.start()
|
||||
print("任务处理线程已启动")
|
||||
|
||||
# 启动Redis监听线程
|
||||
redis_thread = threading.Thread(target=listen_redis_changes, daemon=True)
|
||||
redis_thread.start()
|
||||
print("Redis监听线程已启动")
|
||||
|
||||
print("主程序进入等待状态...")
|
||||
# 保持主线程运行
|
||||
task_thread.join()
|
||||
redis_thread.join()
|
||||
@@ -1,304 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import timedelta
|
||||
import threading
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
|
||||
app = FastAPI()
|
||||
mediapipe_app = FastAPI()
|
||||
app.mount("/mediapipe", mediapipe_app)
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Configuration
|
||||
MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "mediapipe"
|
||||
KAFKA_GROUP_ID = "mediapipe_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 10
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# Initialize Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Initialize Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
class mediapipeEmbedder:
|
||||
def __init__(self, model_path):
|
||||
base_options = python.BaseOptions(model_asset_path=model_path)
|
||||
options = vision.FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True,
|
||||
output_facial_transformation_matrixes=True,
|
||||
num_faces=1
|
||||
)
|
||||
self.detector = vision.FaceLandmarker.create_from_options(options)
|
||||
|
||||
def get_mediapipe_landmarks(self, image):
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
||||
detection_result = self.detector.detect(mp_image)
|
||||
if detection_result.face_landmarks:
|
||||
return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
|
||||
return None
|
||||
|
||||
def process_image(self, image_data):
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
landmarks = self.get_mediapipe_landmarks(img)
|
||||
|
||||
if landmarks is not None:
|
||||
# Calculate a more detailed mediapipe embedding
|
||||
embedding = self.calculate_detailed_embedding(landmarks)
|
||||
|
||||
# Draw landmarks on the image
|
||||
for lm in landmarks:
|
||||
cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)
|
||||
|
||||
return {
|
||||
"embedding": embedding,
|
||||
"landmarks": landmarks.tolist()
|
||||
}, img
|
||||
else:
|
||||
return None, img
|
||||
|
||||
def calculate_detailed_embedding(self, landmarks):
|
||||
# Calculate various statistical features
|
||||
mean = np.mean(landmarks, axis=0)
|
||||
std = np.std(landmarks, axis=0)
|
||||
median = np.median(landmarks, axis=0)
|
||||
min_vals = np.min(landmarks, axis=0)
|
||||
max_vals = np.max(landmarks, axis=0)
|
||||
|
||||
# Calculate pairwise distances between key facial landmarks
|
||||
nose_tip = landmarks[4]
|
||||
left_eye = landmarks[159]
|
||||
right_eye = landmarks[386]
|
||||
left_mouth = landmarks[61]
|
||||
right_mouth = landmarks[291]
|
||||
|
||||
eye_distance = np.linalg.norm(left_eye - right_eye)
|
||||
mouth_width = np.linalg.norm(left_mouth - right_mouth)
|
||||
nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)
|
||||
|
||||
# Calculate face shape features
|
||||
face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
|
||||
face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
|
||||
face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])
|
||||
|
||||
# Combine all features into a single embedding
|
||||
embedding = np.concatenate([
|
||||
mean, std, median, min_vals, max_vals,
|
||||
[eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
|
||||
])
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
embedder = mediapipeEmbedder(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
results, annotated_img = embedder.process_image(image_data)
|
||||
|
||||
if results:
|
||||
# Save annotated image
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return results, annotated_filename
|
||||
else:
|
||||
print(f"No face landmarks detected in image: {filename}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"Error processing image {filename}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if frame_count % fps == 0:
|
||||
frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
|
||||
if frame_results:
|
||||
results.append({"frame": frame_count, "results": frame_results})
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@mediapipe_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"mediapipe_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
||||
|
||||
@mediapipe_app.get("/result/{filename}")
|
||||
async def get_mediapipe_result(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@mediapipe_app.get("/annotated/{filename}")
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"results": results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:mediapipe_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"mediapipe_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
uvicorn.run(app, host="0.0.0.0", port=7006)
|
||||
@@ -1,453 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
|
||||
app = FastAPI()
|
||||
mediapipe_app = FastAPI()
|
||||
app.mount("/mediapipe", mediapipe_app)
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Configuration
|
||||
MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "mediapipe"
|
||||
KAFKA_GROUP_ID = "mediapipe_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 10
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# Initialize Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Initialize Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class mediapipeEmbedder:
|
||||
def __init__(self, model_path):
|
||||
base_options = python.BaseOptions(model_asset_path=model_path)
|
||||
options = vision.FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True,
|
||||
output_facial_transformation_matrixes=True,
|
||||
num_faces=1
|
||||
)
|
||||
self.detector = vision.FaceLandmarker.create_from_options(options)
|
||||
|
||||
def get_mediapipe_landmarks(self, image):
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
||||
detection_result = self.detector.detect(mp_image)
|
||||
if detection_result.face_landmarks:
|
||||
return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
|
||||
return None
|
||||
|
||||
def process_image(self, image_data):
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
landmarks = self.get_mediapipe_landmarks(img)
|
||||
|
||||
if landmarks is not None:
|
||||
# Calculate a more detailed mediapipe embedding
|
||||
embedding = self.calculate_detailed_embedding(landmarks)
|
||||
|
||||
# Draw landmarks on the image
|
||||
for lm in landmarks:
|
||||
cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)
|
||||
|
||||
return {
|
||||
"embedding": embedding,
|
||||
"landmarks": landmarks.tolist()
|
||||
}, img
|
||||
else:
|
||||
return None, img
|
||||
|
||||
def calculate_detailed_embedding(self, landmarks):
|
||||
# Calculate various statistical features
|
||||
mean = np.mean(landmarks, axis=0)
|
||||
std = np.std(landmarks, axis=0)
|
||||
median = np.median(landmarks, axis=0)
|
||||
min_vals = np.min(landmarks, axis=0)
|
||||
max_vals = np.max(landmarks, axis=0)
|
||||
|
||||
# Calculate pairwise distances between key facial landmarks
|
||||
nose_tip = landmarks[4]
|
||||
left_eye = landmarks[159]
|
||||
right_eye = landmarks[386]
|
||||
left_mouth = landmarks[61]
|
||||
right_mouth = landmarks[291]
|
||||
|
||||
eye_distance = np.linalg.norm(left_eye - right_eye)
|
||||
mouth_width = np.linalg.norm(left_mouth - right_mouth)
|
||||
nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)
|
||||
|
||||
# Calculate face shape features
|
||||
face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
|
||||
face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
|
||||
face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])
|
||||
|
||||
# Combine all features into a single embedding
|
||||
embedding = np.concatenate([
|
||||
mean, std, median, min_vals, max_vals,
|
||||
[eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
|
||||
])
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
embedder = mediapipeEmbedder(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
results, annotated_img = embedder.process_image(image_data)
|
||||
|
||||
if results:
|
||||
# Save annotated image
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return results, annotated_filename
|
||||
else:
|
||||
print(f"No face landmarks detected in image: {filename}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"Error processing image {filename}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if frame_count % fps == 0:
|
||||
frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
|
||||
if frame_results:
|
||||
results.append({"frame": frame_count, "results": frame_results})
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@mediapipe_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
|
||||
# 验证 API key
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
# 检查并更新 token 使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "mediapipe"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"mediapipe_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@mediapipe_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_mediapipe_result(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@mediapipe_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"results": results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:mediapipe_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"mediapipe_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
uvicorn.run(app, host="0.0.0.0", port=7006)
|
||||
@@ -1,462 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
import string
|
||||
|
||||
app = FastAPI()
|
||||
mediapipe_app = FastAPI()
|
||||
app.mount("/mediapipe", mediapipe_app)
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Configuration
|
||||
MODEL_PATH = "/home/zydi/models/face_landmarker.task" # Replace with your model path
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "mediapipe"
|
||||
KAFKA_GROUP_ID = "mediapipe_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 10
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# Initialize Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Initialize Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class mediapipeEmbedder:
|
||||
def __init__(self, model_path):
|
||||
base_options = python.BaseOptions(model_asset_path=model_path)
|
||||
options = vision.FaceLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
output_face_blendshapes=True,
|
||||
output_facial_transformation_matrixes=True,
|
||||
num_faces=1
|
||||
)
|
||||
self.detector = vision.FaceLandmarker.create_from_options(options)
|
||||
|
||||
def get_mediapipe_landmarks(self, image):
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
|
||||
detection_result = self.detector.detect(mp_image)
|
||||
if detection_result.face_landmarks:
|
||||
return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
|
||||
return None
|
||||
|
||||
def process_image(self, image_data):
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
landmarks = self.get_mediapipe_landmarks(img)
|
||||
|
||||
if landmarks is not None:
|
||||
# Calculate a more detailed mediapipe embedding
|
||||
embedding = self.calculate_detailed_embedding(landmarks)
|
||||
|
||||
# Draw landmarks on the image
|
||||
for lm in landmarks:
|
||||
cv2.circle(img, (int(lm[0]*img.shape[1]), int(lm[1]*img.shape[0])), 2, (0,255,0), -1)
|
||||
|
||||
return {
|
||||
"embedding": embedding,
|
||||
"landmarks": landmarks.tolist()
|
||||
}, img
|
||||
else:
|
||||
return None, img
|
||||
|
||||
def calculate_detailed_embedding(self, landmarks):
|
||||
# Calculate various statistical features
|
||||
mean = np.mean(landmarks, axis=0)
|
||||
std = np.std(landmarks, axis=0)
|
||||
median = np.median(landmarks, axis=0)
|
||||
min_vals = np.min(landmarks, axis=0)
|
||||
max_vals = np.max(landmarks, axis=0)
|
||||
|
||||
# Calculate pairwise distances between key facial landmarks
|
||||
nose_tip = landmarks[4]
|
||||
left_eye = landmarks[159]
|
||||
right_eye = landmarks[386]
|
||||
left_mouth = landmarks[61]
|
||||
right_mouth = landmarks[291]
|
||||
|
||||
eye_distance = np.linalg.norm(left_eye - right_eye)
|
||||
mouth_width = np.linalg.norm(left_mouth - right_mouth)
|
||||
nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)
|
||||
|
||||
# Calculate face shape features
|
||||
face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
|
||||
face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
|
||||
face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])
|
||||
|
||||
# Combine all features into a single embedding
|
||||
embedding = np.concatenate([
|
||||
mean, std, median, min_vals, max_vals,
|
||||
[eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
|
||||
])
|
||||
|
||||
return embedding.tolist()
|
||||
|
||||
embedder = mediapipeEmbedder(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
results, annotated_img = embedder.process_image(image_data)
|
||||
|
||||
if results:
|
||||
# Save annotated image
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return results, annotated_filename
|
||||
else:
|
||||
print(f"No face landmarks detected in image: {filename}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"Error processing image {filename}: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
annotated_filename = f"mediapipe_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
if frame_count % fps == 0:
|
||||
frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
|
||||
if frame_results:
|
||||
results.append({"frame": frame_count, "results": frame_results})
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@mediapipe_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "mediapipe"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"mediapipe_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@mediapipe_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_mediapipe_result(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@mediapipe_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"mediapipe_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"results": results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:mediapipe_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"mediapipe_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
uvicorn.run(app, host="0.0.0.0", port=7006)
|
||||
@@ -1,312 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
import torch
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
pose_app = FastAPI()
|
||||
app.mount("/pose", pose_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x-pose.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "pose" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "pose_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 3
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
class PoseDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = PoseDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"pose_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"pose_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@pose_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"pose_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
||||
|
||||
@pose_app.get("/result/{filename}")
|
||||
async def get_pose_result(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@pose_app.get("/annotated/{filename}")
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"pose_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"pose_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7001)
|
||||
@@ -1,461 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import torch
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
pose_app = FastAPI()
|
||||
app.mount("/pose", pose_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x-pose.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "pose" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "pose_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 3
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class PoseDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = PoseDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"pose_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"pose_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@pose_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
|
||||
# 验证 API key
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8x-pose"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"pose_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@pose_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_pose_result(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@pose_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"pose_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"pose_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7001)
|
||||
@@ -1,470 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import torch
|
||||
import string
|
||||
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
app = FastAPI()
|
||||
pose_app = FastAPI()
|
||||
app.mount("/pose", pose_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x-pose.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "pose" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "pose_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 3
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class PoseDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path).to('cuda:1')
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame, device='cuda:1')
|
||||
return results
|
||||
|
||||
def format_results(self, results):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
keypoints = r.keypoints
|
||||
for i in range(len(boxes)):
|
||||
box = boxes[i]
|
||||
kpts = keypoints[i]
|
||||
formatted_results.append({
|
||||
"bbox": box.xyxy.tolist()[0],
|
||||
"confidence": box.conf.item(),
|
||||
"keypoints": kpts.xy.tolist()[0]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, results, original_shape):
|
||||
for r in results:
|
||||
annotated_frame = r.plot(img=frame)
|
||||
# 调整坐标以适应原始图像大小
|
||||
h, w = annotated_frame.shape[:2]
|
||||
scale_x, scale_y = original_shape[1] / w, original_shape[0] / h
|
||||
annotated_frame = cv2.resize(annotated_frame, (original_shape[1], original_shape[0]))
|
||||
return annotated_frame
|
||||
|
||||
detector = PoseDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = img.shape
|
||||
# Convert BGR to RGB
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize image to fit model requirements (640x640)
|
||||
img_resized = cv2.resize(img, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
img_tensor = torch.from_numpy(img_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
img_tensor = img_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(img_tensor)
|
||||
|
||||
# Format results for JSON
|
||||
json_results = detector.format_results(results)
|
||||
|
||||
# Draw results on original image
|
||||
annotated_img = detector.draw_results(img_resized, results, original_shape)
|
||||
|
||||
# Save annotated image
|
||||
annotated_filename = f"pose_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, cv2.cvtColor(annotated_img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"pose_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
# Get video properties
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
# Create output video file
|
||||
annotated_filename = f"pose_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Process one frame per second
|
||||
if frame_count % fps == 0:
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Resize frame to fit model requirements (640x640)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640))
|
||||
|
||||
# Normalize and reshape to BCHW format
|
||||
frame_tensor = torch.from_numpy(frame_resized).float().permute(2, 0, 1).unsqueeze(0) / 255.0
|
||||
frame_tensor = frame_tensor.to('cuda:1')
|
||||
|
||||
results = detector.detect(frame_tensor)
|
||||
frame_json_results = detector.format_results(results)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
# Draw results on original frame
|
||||
annotated_frame = detector.draw_results(frame_resized, results, original_shape)
|
||||
# Convert RGB back to BGR for OpenCV
|
||||
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
||||
else:
|
||||
annotated_frame = frame
|
||||
|
||||
out.write(annotated_frame)
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
# Clean up temporary input video file
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"Error processing video: {str(e)}")
|
||||
return None, None
|
||||
|
||||
@pose_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8x-pose"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"pose_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@pose_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_pose_result(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@pose_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"pose_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"pose_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:pose_result:*') # 监听所有pose_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"pose_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7001)
|
||||
@@ -1,356 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
qwenvl_app = FastAPI()
|
||||
app.mount("/qwenvl", qwenvl_app)
|
||||
torch.cuda.set_device(1)
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "qwenvl"
|
||||
KAFKA_GROUP_ID = "qwenvl_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 8
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
|
||||
# 初始化模型
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype="auto", device_map="cuda:1"
|
||||
)
|
||||
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, processor):
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.MAX_NUM_FRAMES = 10
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file)
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_media(self, media_data, object_name, media_type='image'):
|
||||
if not media_data:
|
||||
raise ValueError(f"Empty {media_type} data for {object_name}")
|
||||
|
||||
print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")
|
||||
|
||||
if media_type == 'video':
|
||||
frames = self.encode_video(media_data)
|
||||
media_content = {"type": "video", "video": frames, "fps": 1.0}
|
||||
else: # image
|
||||
image = Image.open(io.BytesIO(media_data))
|
||||
media_content = {"type": "image", "image": image}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
media_content,
|
||||
{"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to('cuda:1')
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
answer = self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
result = {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
if media_type == 'video':
|
||||
result["num_frames"] = len(frames)
|
||||
|
||||
return result
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
return self.process_media(video_data, object_name, media_type='video')
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
return self.process_media(image_data, object_name, media_type='image')
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]:
|
||||
for item in item_list:
|
||||
if item in answer:
|
||||
info[key].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, processor)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str):
|
||||
content = await file.read()
|
||||
# 获取原始文件的后缀
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
# 生成新的文件名,包含 UUID 和原始后缀
|
||||
filename = f"qwenvl_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename}
|
||||
|
||||
@qwenvl_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_video")
|
||||
async def analyze_video(file: UploadFile = File(...)):
|
||||
try:
|
||||
return await process_file(file, "video")
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_image")
|
||||
async def analyze_image(file: UploadFile = File(...)):
|
||||
try:
|
||||
return await process_file(file, "image")
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@qwenvl_app.get("/result/{filename}")
|
||||
async def get_result(filename: str):
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7005)
|
||||
@@ -1,508 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
|
||||
app = FastAPI()
|
||||
qwenvl_app = FastAPI()
|
||||
app.mount("/qwenvl", qwenvl_app)
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "qwenvl"
|
||||
KAFKA_GROUP_ID = "qwenvl_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 8
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
# 初始化模型
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype="auto", device_map="cuda:1"
|
||||
)
|
||||
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, processor):
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.MAX_NUM_FRAMES = 10
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file)
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_media(self, media_data, object_name, media_type='image'):
|
||||
if not media_data:
|
||||
raise ValueError(f"Empty {media_type} data for {object_name}")
|
||||
|
||||
print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")
|
||||
|
||||
if media_type == 'video':
|
||||
frames = self.encode_video(media_data)
|
||||
media_content = {"type": "video", "video": frames, "fps": 1.0}
|
||||
else: # image
|
||||
image = Image.open(io.BytesIO(media_data))
|
||||
media_content = {"type": "image", "image": image}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
media_content,
|
||||
{"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to('cuda:1')
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
answer = self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
result = {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
if media_type == 'video':
|
||||
result["num_frames"] = len(frames)
|
||||
|
||||
return result
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
return self.process_media(video_data, object_name, media_type='video')
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
return self.process_media(image_data, object_name, media_type='image')
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]:
|
||||
for item in item_list:
|
||||
if item in answer:
|
||||
info[key].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, processor)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str, api_key: str):
|
||||
content = await file.read()
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
filename = f"qwenvl_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算token
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新token使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新token使用量
|
||||
model_name = "Qwen2-VL-2B-Instruct"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@qwenvl_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type, api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_video")
|
||||
async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "video", api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_image")
|
||||
async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "image", api_key)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@qwenvl_app.get("/result/{filename}")
|
||||
async def get_result(filename: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7005)
|
||||
@@ -1,518 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from redis import Redis
|
||||
import io
|
||||
import re
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
import threading
|
||||
import string
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
qwenvl_app = FastAPI()
|
||||
app.mount("/qwenvl", qwenvl_app)
|
||||
torch.cuda.set_device(1)
|
||||
|
||||
# CORS设置
|
||||
ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct"
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "qwenvl"
|
||||
KAFKA_GROUP_ID = "qwenvl_group"
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 8
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
# 初始化模型
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
MODEL_PATH, torch_dtype="auto", device_map="cuda:1"
|
||||
)
|
||||
|
||||
min_pixels = 128*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained(MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
class MediaAnalysisSystem:
|
||||
def __init__(self, model, processor):
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.MAX_NUM_FRAMES = 10
|
||||
|
||||
def encode_video(self, video_data):
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
return [l[int(i * gap + gap / 2)] for i in range(n)]
|
||||
|
||||
video_file = io.BytesIO(video_data)
|
||||
vr = VideoReader(video_file)
|
||||
sample_fps = round(vr.get_avg_fps() / 1)
|
||||
frame_idx = list(range(0, len(vr), sample_fps))
|
||||
if len(frame_idx) > self.MAX_NUM_FRAMES:
|
||||
frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
print('num frames:', len(frames))
|
||||
return frames
|
||||
|
||||
def process_media(self, media_data, object_name, media_type='image'):
|
||||
if not media_data:
|
||||
raise ValueError(f"Empty {media_type} data for {object_name}")
|
||||
|
||||
print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")
|
||||
|
||||
if media_type == 'video':
|
||||
frames = self.encode_video(media_data)
|
||||
media_content = {"type": "video", "video": frames, "fps": 1.0}
|
||||
else: # image
|
||||
image = Image.open(io.BytesIO(media_data))
|
||||
media_content = {"type": "image", "image": image}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
media_content,
|
||||
{"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to('cuda:1')
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
answer = self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
|
||||
extracted_info = self.extract_info(answer)
|
||||
|
||||
result = {
|
||||
"original_answer": answer,
|
||||
"extracted_info": extracted_info,
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
if media_type == 'video':
|
||||
result["num_frames"] = len(frames)
|
||||
|
||||
return result
|
||||
|
||||
def process_video(self, video_data, object_name):
|
||||
return self.process_media(video_data, object_name, media_type='video')
|
||||
|
||||
def process_image(self, image_data, object_name):
|
||||
return self.process_media(image_data, object_name, media_type='image')
|
||||
|
||||
@staticmethod
|
||||
def extract_time_from_filename(object_name):
|
||||
filename = os.path.basename(object_name)
|
||||
time_str = filename.split('_')[0] + '_' + filename.split('_')[1].split('.')[0]
|
||||
|
||||
try:
|
||||
start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
|
||||
end_time = start_time + timedelta(seconds=10)
|
||||
return start_time, end_time
|
||||
except ValueError:
|
||||
print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
|
||||
return datetime.now(), datetime.now() + timedelta(seconds=10)
|
||||
|
||||
@staticmethod
|
||||
def extract_info(answer):
|
||||
info = {
|
||||
"environment": None,
|
||||
"num_people": None,
|
||||
"actions": [],
|
||||
"interactions": [],
|
||||
"objects": [],
|
||||
"furniture": []
|
||||
}
|
||||
|
||||
environments = ["办公室", "室内", "室外", "会议室"]
|
||||
for env in environments:
|
||||
if env in answer.lower():
|
||||
info["environment"] = env
|
||||
break
|
||||
|
||||
people_patterns = [
|
||||
r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
|
||||
r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
|
||||
r'(男|女)(性|生|士)',
|
||||
r'(成年|未成年|青少年|老年)\s*(人|群体)',
|
||||
r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
|
||||
r'(群众|民众|大众|公众)',
|
||||
r'(男女|老少|老幼|大人|小孩)'
|
||||
]
|
||||
for pattern in people_patterns:
|
||||
match = re.search(pattern, answer)
|
||||
if match:
|
||||
if match.group(1).isdigit():
|
||||
info["num_people"] = int(match.group(1))
|
||||
elif match.group(1) in ['一个', '一']:
|
||||
info["num_people"] = 1
|
||||
else:
|
||||
num_word_to_digit = {
|
||||
'二': 2, '三': 3, '四': 4, '五': 5,
|
||||
'六': 6, '七': 7, '八': 8, '九': 9, '十': 10
|
||||
}
|
||||
info["num_people"] = num_word_to_digit.get(match.group(1), 0)
|
||||
break
|
||||
|
||||
actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "转身", "跳跃", "跳", "躺", "睡", "说话"]
|
||||
interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
|
||||
objects = ["水瓶", "办公用品", "文件", "电脑"]
|
||||
furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
|
||||
|
||||
for item_list, key in [(actions, "actions"), (interactions, "interactions"), (objects, "objects"), (furniture, "furniture")]:
|
||||
for item in item_list:
|
||||
if item in answer:
|
||||
info[key].append(item)
|
||||
|
||||
return info
|
||||
|
||||
# 初始化 MediaAnalysisSystem
|
||||
media_analysis_system = MediaAnalysisSystem(model, processor)
|
||||
|
||||
async def process_file(file: UploadFile, file_type: str, api_key_info: dict):
|
||||
content = await file.read()
|
||||
original_extension = os.path.splitext(file.filename)[1]
|
||||
|
||||
filename = f"qwenvl_{uuid.uuid4()}{original_extension}"
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算token
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新token使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新token使用量
|
||||
model_name = "Qwen2-VL-2B-Instruct"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": filename,
|
||||
"type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@qwenvl_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
try:
|
||||
file_type = "image" if file.content_type.startswith("image") else "video"
|
||||
return await process_file(file, file_type, api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_video")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
api_key_info = await verify_api_key(api_key_info)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "video", api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
@qwenvl_app.post("/analyze_image")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
api_key_info = await verify_api_key(api_key_info)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
try:
|
||||
return await process_file(file, "image", api_key_info)
|
||||
except Exception as e:
|
||||
return JSONResponse(content={"error": str(e)}, status_code=500)
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
try:
|
||||
if isinstance(message.value, dict):
|
||||
task = message.value
|
||||
else:
|
||||
task = json.loads(message.value.decode('utf-8'))
|
||||
|
||||
filename = task['filename']
|
||||
file_type = task['type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
if file_type == "video":
|
||||
result = media_analysis_system.process_video(file_data, filename)
|
||||
elif file_type == "image":
|
||||
result = media_analysis_system.process_image(file_data, filename)
|
||||
|
||||
# 保存结果到 JSON 文件
|
||||
result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
|
||||
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 将结果存储在 Redis 中
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
if 'filename' in locals() and 'file_type' in locals():
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
else:
|
||||
print("Error occurred before task details were extracted")
|
||||
|
||||
@qwenvl_app.get("/result/{filename}")
|
||||
async def get_result(filename: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
for file_type in ["video", "image"]:
|
||||
redis_key = f"{file_type}_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_json = json.loads(result)
|
||||
|
||||
if result_json.get("status") == "queued":
|
||||
return {"status": "queued", "message": "Your request is in the queue and will be processed soon."}
|
||||
elif result_json.get("status") == "processing":
|
||||
return {"status": "processing", "message": "Your request is being processed."}
|
||||
else:
|
||||
return result_json
|
||||
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
|
||||
async def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@5__:*_result:*')
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
print(f"Key changed: {key}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 在后台线程中启动Kafka消费者
|
||||
consumer_thread = threading.Thread(target=process_task, daemon=True)
|
||||
consumer_thread.start()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7005)
|
||||
@@ -1,315 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from ultralytics import YOLO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from redis import Redis
|
||||
import io
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
import torch
|
||||
torch.cuda.set_device(1)
|
||||
import colorsys
|
||||
|
||||
app = FastAPI()
|
||||
yolo_app = FastAPI()
|
||||
app.mount("/yolo", yolo_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "yolo" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "yolo_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 6
|
||||
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
class yoloDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
|
||||
def format_results(self, results, original_shape):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
|
||||
# 缩放坐标到原始图像尺寸
|
||||
x1, x2 = [x * original_shape[1] / 640 for x in [x1, x2]]
|
||||
y1, y2 = [y * original_shape[0] / 640 for y in [y1, y2]]
|
||||
|
||||
conf = box.conf.item()
|
||||
cls = int(box.cls.item())
|
||||
name = self.model.names[cls]
|
||||
|
||||
formatted_results.append({
|
||||
"class": name,
|
||||
"confidence": conf,
|
||||
"bbox": [x1, y1, x2, y2]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, formatted_results):
|
||||
for result in formatted_results:
|
||||
x1, y1, x2, y2 = map(int, result['bbox'])
|
||||
name = result['class']
|
||||
conf = result['confidence']
|
||||
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) # 使用固定的绿色
|
||||
label = f"{name} {conf:.2f}"
|
||||
(text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
|
||||
cv2.rectangle(frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), (0, 255, 0), -1)
|
||||
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
||||
|
||||
return frame
|
||||
|
||||
detector = yoloDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
original_img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = original_img.shape
|
||||
|
||||
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (640, 640))
|
||||
img = img.transpose((2, 0, 1))
|
||||
img = np.ascontiguousarray(img)
|
||||
img = torch.from_numpy(img).float()
|
||||
img /= 255.0
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
json_results = detector.format_results(results, original_shape)
|
||||
|
||||
annotated_img = detector.draw_results(original_img, json_results)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理图像时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"yolo_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 每秒只处理一帧
|
||||
if frame_count % fps == 0:
|
||||
preprocessed_frame = preprocess_frame(frame)
|
||||
|
||||
results = detector.detect(preprocessed_frame)
|
||||
frame_json_results = detector.format_results(results, original_shape)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, frame_json_results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理视频时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def preprocess_frame(frame):
|
||||
# 预处理单个视频帧
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640)) # 调整为YOLO输入尺寸
|
||||
frame_transposed = frame_resized.transpose((2, 0, 1)) # HWC转为CHW
|
||||
frame_contiguous = np.ascontiguousarray(frame_transposed)
|
||||
frame_tensor = torch.from_numpy(frame_contiguous).float()
|
||||
frame_normalized = frame_tensor / 255.0 # 归一化到[0, 1]
|
||||
frame_batched = frame_normalized.unsqueeze(0) # 添加批次维度
|
||||
return frame_batched
|
||||
|
||||
@yolo_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# Save the original file
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Send processing task to Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Set initial status in Redis
|
||||
redis_key = f"yolo_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
return JSONResponse(content={"message": "File uploaded and queued for processing", "filename": new_filename})
|
||||
|
||||
@yolo_app.get("/result/{filename}")
|
||||
async def get_yolo_result(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@yolo_app.get("/annotated/{filename}")
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:yolo_result:*') # 监听所有yolo_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"yolo_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7003)
|
||||
@@ -1,462 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from redis import Redis
|
||||
from ultralytics import YOLO
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
yolo_app = FastAPI()
|
||||
app.mount("/yolo", yolo_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "yolo" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "yolo_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 6
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 添加API密钥验证
|
||||
API_KEY_NAME = "X-API-Key"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
|
||||
if api_key is None:
|
||||
raise HTTPException(status_code=400, detail="API密钥缺失")
|
||||
return api_key
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
return None
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
return None
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
return None
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class yoloDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
|
||||
def format_results(self, results, original_shape):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
|
||||
# 缩放坐标到原始图像尺寸
|
||||
x1, x2 = [x * original_shape[1] / 640 for x in [x1, x2]]
|
||||
y1, y2 = [y * original_shape[0] / 640 for y in [y1, y2]]
|
||||
|
||||
conf = box.conf.item()
|
||||
cls = int(box.cls.item())
|
||||
name = self.model.names[cls]
|
||||
|
||||
formatted_results.append({
|
||||
"class": name,
|
||||
"confidence": conf,
|
||||
"bbox": [x1, y1, x2, y2]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, formatted_results):
|
||||
for result in formatted_results:
|
||||
x1, y1, x2, y2 = map(int, result['bbox'])
|
||||
name = result['class']
|
||||
conf = result['confidence']
|
||||
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) # 使用固定的绿色
|
||||
label = f"{name} {conf:.2f}"
|
||||
(text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
|
||||
cv2.rectangle(frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), (0, 255, 0), -1)
|
||||
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
||||
|
||||
return frame
|
||||
|
||||
detector = yoloDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
original_img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = original_img.shape
|
||||
|
||||
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (640, 640))
|
||||
img = img.transpose((2, 0, 1))
|
||||
img = np.ascontiguousarray(img)
|
||||
img = torch.from_numpy(img).float()
|
||||
img /= 255.0
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
json_results = detector.format_results(results, original_shape)
|
||||
|
||||
annotated_img = detector.draw_results(original_img, json_results)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理图像时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"yolo_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 每秒只处理一帧
|
||||
if frame_count % fps == 0:
|
||||
preprocessed_frame = preprocess_frame(frame)
|
||||
|
||||
results = detector.detect(preprocessed_frame)
|
||||
frame_json_results = detector.format_results(results, original_shape)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, frame_json_results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理视频时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def preprocess_frame(frame):
|
||||
# 预处理单个视频帧
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640)) # 调整为YOLO输入尺寸
|
||||
frame_transposed = frame_resized.transpose((2, 0, 1)) # HWC转为CHW
|
||||
frame_contiguous = np.ascontiguousarray(frame_transposed)
|
||||
frame_tensor = torch.from_numpy(frame_contiguous).float()
|
||||
frame_normalized = frame_tensor / 255.0 # 归一化到[0, 1]
|
||||
frame_batched = frame_normalized.unsqueeze(0) # 添加批次维度
|
||||
return frame_batched
|
||||
|
||||
@yolo_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),api_key: str = Depends(get_api_key)):
|
||||
# 验证 API key
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存原始文件
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8x"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# 发送处理任务到 Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
# 在 Redis 中设置初始状态
|
||||
redis_key = f"yolo_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
@yolo_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_yolo_result(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@yolo_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:yolo_result:*') # 监听所有yolo_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"yolo_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7003)
|
||||
@@ -1,472 +0,0 @@
|
||||
import os
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from redis import Redis
|
||||
from ultralytics import YOLO
|
||||
import json
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import threading
|
||||
import string
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
yolo_app = FastAPI()
|
||||
app.mount("/yolo", yolo_app)
|
||||
|
||||
# CORS配置
|
||||
ALLOWED_ORIGINS = 'https://beta.obscura.work'
|
||||
|
||||
# 只为主应用添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
MODEL_PATH = "/home/zydi/models/yolov8x.pt" # 请替换为您的模型路径
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_TOPIC = "yolo" # 指定Kafka topic
|
||||
KAFKA_GROUP_ID = "yolo_group" # 指定消费者组ID
|
||||
|
||||
REDIS_HOST = "222.186.10.253"
|
||||
REDIS_PORT = 6379
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
REDIS_DB = 6
|
||||
REDIS_API_DB = 12
|
||||
REDIS_API_USAGE_DB = 13
|
||||
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
|
||||
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化 Kafka
|
||||
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = cv2.imread(file_path)
|
||||
if img is None:
|
||||
raise ValueError("无法读取图片文件")
|
||||
height, width = img.shape[:2]
|
||||
pixel_count = height * width
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
cap = cv2.VideoCapture(file_path)
|
||||
if not cap.isOpened():
|
||||
raise ValueError("无法打开视频文件")
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
cap.release()
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
class yoloDetector:
|
||||
def __init__(self, model_path):
|
||||
self.model = YOLO(model_path)
|
||||
|
||||
def detect(self, frame):
|
||||
results = self.model(frame)
|
||||
return results
|
||||
|
||||
def format_results(self, results, original_shape):
|
||||
formatted_results = []
|
||||
for r in results:
|
||||
boxes = r.boxes
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].tolist()
|
||||
|
||||
# 缩放坐标到原始图像尺寸
|
||||
x1, x2 = [x * original_shape[1] / 640 for x in [x1, x2]]
|
||||
y1, y2 = [y * original_shape[0] / 640 for y in [y1, y2]]
|
||||
|
||||
conf = box.conf.item()
|
||||
cls = int(box.cls.item())
|
||||
name = self.model.names[cls]
|
||||
|
||||
formatted_results.append({
|
||||
"class": name,
|
||||
"confidence": conf,
|
||||
"bbox": [x1, y1, x2, y2]
|
||||
})
|
||||
return formatted_results
|
||||
|
||||
def draw_results(self, frame, formatted_results):
|
||||
for result in formatted_results:
|
||||
x1, y1, x2, y2 = map(int, result['bbox'])
|
||||
name = result['class']
|
||||
conf = result['confidence']
|
||||
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) # 使用固定的绿色
|
||||
label = f"{name} {conf:.2f}"
|
||||
(text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
|
||||
cv2.rectangle(frame, (x1, y1 - text_height - 5), (x1 + text_width, y1), (0, 255, 0), -1)
|
||||
cv2.putText(frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
|
||||
|
||||
return frame
|
||||
|
||||
detector = yoloDetector(MODEL_PATH)
|
||||
|
||||
def process_image(image_data, filename):
|
||||
try:
|
||||
original_img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
||||
original_shape = original_img.shape
|
||||
|
||||
img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (640, 640))
|
||||
img = img.transpose((2, 0, 1))
|
||||
img = np.ascontiguousarray(img)
|
||||
img = torch.from_numpy(img).float()
|
||||
img /= 255.0
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
results = detector.detect(img)
|
||||
|
||||
json_results = detector.format_results(results, original_shape)
|
||||
|
||||
annotated_img = detector.draw_results(original_img, json_results)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
annotated_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
cv2.imwrite(annotated_path, annotated_img)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理图像时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def process_video(video_data, filename):
|
||||
try:
|
||||
temp_video_path = os.path.join(UPLOAD_DIR, f"yolo_{filename}")
|
||||
with open(temp_video_path, 'wb') as temp_video:
|
||||
temp_video.write(video_data)
|
||||
|
||||
cap = cv2.VideoCapture(temp_video_path)
|
||||
frame_count = 0
|
||||
json_results = []
|
||||
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
original_shape = (height, width)
|
||||
|
||||
annotated_filename = f"yolo_{filename}"
|
||||
output_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (width, height))
|
||||
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# 每秒只处理一帧
|
||||
if frame_count % fps == 0:
|
||||
preprocessed_frame = preprocess_frame(frame)
|
||||
|
||||
results = detector.detect(preprocessed_frame)
|
||||
frame_json_results = detector.format_results(results, original_shape)
|
||||
json_results.append({"frame": frame_count, "detections": frame_json_results})
|
||||
|
||||
annotated_frame = detector.draw_results(frame, frame_json_results)
|
||||
out.write(annotated_frame)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
|
||||
os.remove(temp_video_path)
|
||||
|
||||
return json_results, annotated_filename
|
||||
except Exception as e:
|
||||
print(f"处理视频时出错: {str(e)}")
|
||||
return None, None
|
||||
|
||||
def preprocess_frame(frame):
|
||||
# 预处理单个视频帧
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame_resized = cv2.resize(frame_rgb, (640, 640)) # 调整为YOLO输入尺寸
|
||||
frame_transposed = frame_resized.transpose((2, 0, 1)) # HWC转为CHW
|
||||
frame_contiguous = np.ascontiguousarray(frame_transposed)
|
||||
frame_tensor = torch.from_numpy(frame_contiguous).float()
|
||||
frame_normalized = frame_tensor / 255.0 # 归一化到[0, 1]
|
||||
frame_batched = frame_normalized.unsqueeze(0) # 添加批次维度
|
||||
return frame_batched
|
||||
|
||||
@yolo_app.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存原始文件
|
||||
original_file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(original_file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(original_file_path, file_type)
|
||||
if tokens_required is None or tokens_required <= 0:
|
||||
raise HTTPException(status_code=500, detail="无法计算所需的token数量")
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
model_name = "yolov8x"
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
|
||||
# 发送处理任务到 Kafka
|
||||
producer.send(KAFKA_TOPIC, json.dumps({
|
||||
"filename": new_filename,
|
||||
"file_type": file_type
|
||||
}).encode('utf-8'))
|
||||
|
||||
# 在 Redis 中设置初始状态
|
||||
redis_key = f"yolo_result:{new_filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "queued"}))
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return JSONResponse(content={
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"filename": new_filename,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
})
|
||||
|
||||
|
||||
@yolo_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_yolo_result(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
return JSONResponse(content=result_data) # 直接返回整个结果,包括 status
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="Result not found")
|
||||
|
||||
@yolo_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
|
||||
async def get_annotated_file(filename: str):
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
result = redis_client.get(redis_key)
|
||||
if result:
|
||||
result_data = json.loads(result)
|
||||
if result_data["status"] == "completed":
|
||||
annotated_filename = result_data["annotated_filename"]
|
||||
file_path = os.path.join(RESULT_DIR, annotated_filename)
|
||||
if os.path.exists(file_path):
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
file_extension = os.path.splitext(annotated_filename)[1].lower()
|
||||
return StreamingResponse(iterfile(), media_type=f"image/{file_extension[1:]}" if file_extension in ['.jpg', '.jpeg', '.png'] else "video/mp4")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Annotated file not found")
|
||||
|
||||
def process_task():
|
||||
for message in consumer:
|
||||
task = message.value
|
||||
filename = task['filename']
|
||||
file_type = task['file_type']
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, filename)
|
||||
|
||||
# Update status to "processing"
|
||||
redis_key = f"yolo_result:{filename}"
|
||||
redis_client.set(redis_key, json.dumps({"status": "processing"}))
|
||||
|
||||
try:
|
||||
if file_type == "image":
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_image(content, filename)
|
||||
else:
|
||||
with open(file_path, 'rb') as f:
|
||||
content = f.read()
|
||||
json_results, annotated_filename = process_video(content, filename)
|
||||
|
||||
if json_results and annotated_filename:
|
||||
redis_client.set(redis_key, json.dumps({
|
||||
"json_results": json_results,
|
||||
"status": "completed",
|
||||
"annotated_filename": annotated_filename
|
||||
}))
|
||||
else:
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed"}))
|
||||
except Exception as e:
|
||||
print(f"Error processing task: {str(e)}")
|
||||
redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
||||
|
||||
def listen_redis_changes():
|
||||
pubsub = redis_client.pubsub()
|
||||
pubsub.psubscribe('__keyspace@3__:yolo_result:*') # 监听所有yolo_result键的变化
|
||||
|
||||
for message in pubsub.listen():
|
||||
if message['type'] == 'pmessage':
|
||||
key = message['channel'].decode('utf-8').split(':')[-1]
|
||||
operation = message['data'].decode('utf-8')
|
||||
|
||||
if operation == 'set':
|
||||
value = redis_client.get(f"yolo_result:{key}")
|
||||
if value:
|
||||
result = json.loads(value)
|
||||
print(f"Status update for {key}: {result['status']}")
|
||||
|
||||
# 这里可以添加其他处理逻辑,比如发送通知等
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动处理任务的线程
|
||||
threading.Thread(target=process_task, daemon=True).start()
|
||||
|
||||
# 启动Redis监听线程
|
||||
threading.Thread(target=listen_redis_changes, daemon=True).start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=7003)
|
||||
@@ -1,81 +0,0 @@
|
||||
# config.py
|
||||
|
||||
import os
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = "222.186.10.253:9092"
|
||||
KAFKA_GROUP_ID_PREFIX = "group"
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = "150.158.144.159"
|
||||
REDIS_PORT = 13003
|
||||
REDIS_PASSWORD = "Obscura@2024"
|
||||
MAIN_REDIS_DB = 0
|
||||
REDIS_API_DB = 2
|
||||
REDIS_API_USAGE_DB = 3
|
||||
# 目录配置
|
||||
UPLOAD_DIR = "/obscura/task/upload"
|
||||
RESULT_DIR = "/obscura/task/result"
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 模型配置
|
||||
YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt"
|
||||
POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt"
|
||||
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
|
||||
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
|
||||
FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt"
|
||||
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
|
||||
# COMPARE_MODEL_PATH = "/obscura/models/insightface/insw_r100_glint360k.onnx"
|
||||
# Ollama配置
|
||||
OLLAMA_URL = "http://127.0.0.1:11434/api/generate"
|
||||
|
||||
# 各个worker的配置
|
||||
WORKER_CONFIGS = {
|
||||
"yolo": {
|
||||
"kafka_topic": "yolo",
|
||||
"redis_db": 4,
|
||||
},
|
||||
"pose": {
|
||||
"kafka_topic": "pose",
|
||||
"redis_db": 5,
|
||||
},
|
||||
"qwenvl": {
|
||||
"kafka_topic": "qwenvl",
|
||||
"redis_db": 9,
|
||||
},
|
||||
"qwenvl_analyze": {
|
||||
"kafka_topic": "qwenvl_analyze",
|
||||
"redis_db": 32,
|
||||
},
|
||||
"cpm": {
|
||||
"kafka_topic": "cpm",
|
||||
"redis_db": 8,
|
||||
},
|
||||
"cpm_analyze": {
|
||||
"kafka_topic": "cpm_analyze",
|
||||
"redis_db": 31,
|
||||
},
|
||||
"fall": {
|
||||
"kafka_topic": "fall",
|
||||
"redis_db": 6,
|
||||
},
|
||||
"face": {
|
||||
"kafka_topic": "face",
|
||||
"redis_db": 7,
|
||||
},
|
||||
"mediapipe": {
|
||||
"kafka_topic": "mediapipe",
|
||||
"redis_db": 10,
|
||||
},
|
||||
"compare": {
|
||||
"kafka_topic": "compare",
|
||||
"redis_db": 30,
|
||||
}
|
||||
}
|
||||
|
||||
# GPU设置
|
||||
CUDA_DEVICE_0 = "cuda:0"
|
||||
CUDA_DEVICE_1 = "cuda:1"
|
||||
@@ -1,419 +0,0 @@
|
||||
# main.py
|
||||
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import string
|
||||
from decord import VideoReader
|
||||
from PIL import Image
|
||||
from fastapi.responses import FileResponse
|
||||
import logging
|
||||
from config import *
|
||||
|
||||
app = FastAPI()
|
||||
v1_app = FastAPI()
|
||||
app.mount("/v1", v1_app)
|
||||
|
||||
|
||||
# CORS设置
|
||||
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = KAFKA_BROKER
|
||||
REDIS_HOST = REDIS_HOST
|
||||
REDIS_PORT = REDIS_PORT
|
||||
REDIS_PASSWORD = REDIS_PASSWORD
|
||||
REDIS_DB = MAIN_REDIS_DB
|
||||
REDIS_API_DB = REDIS_API_DB
|
||||
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
|
||||
UPLOAD_DIR = UPLOAD_DIR
|
||||
RESULT_DIR = RESULT_DIR
|
||||
MAX_FILE_AGE = timedelta(hours=1)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||
|
||||
# 定义支持的任务类型
|
||||
KAFKA_TOPICS = {
|
||||
'pose': 'pose',
|
||||
'mediapipe': 'mediapipe',
|
||||
'qwenvl': 'qwenvl',
|
||||
'yolo': 'yolo',
|
||||
'fall': 'fall',
|
||||
'face': 'face',
|
||||
'cpm': 'cpm'
|
||||
}
|
||||
|
||||
TASK_TYPES = list(KAFKA_TOPICS.keys())
|
||||
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_DB
|
||||
)
|
||||
|
||||
redis_api_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_DB
|
||||
)
|
||||
|
||||
redis_api_usage_client = Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
password=REDIS_PASSWORD,
|
||||
db=REDIS_API_USAGE_DB
|
||||
)
|
||||
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
|
||||
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
|
||||
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
|
||||
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
|
||||
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
|
||||
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
|
||||
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
|
||||
|
||||
@v1_app.get('/favicon.ico', include_in_schema=False)
|
||||
async def favicon():
|
||||
file_name = "favicon.ico"
|
||||
file_path = os.path.join(app.root_path, "static", file_name)
|
||||
if os.path.isfile(file_path):
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
return {"message": "Favicon not found"}, 404
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
# 定义 base62 字符集
|
||||
BASE62 = string.digits + string.ascii_letters
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
logging.info(f"验证API密钥: {api_key}")
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
logging.warning(f"API密钥不存在: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
logging.warning(f"API密钥已停用: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
logging.warning(f"API密钥已过期: {api_key}")
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
logging.info(f"API密钥验证成功: {api_key}")
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
# 更新总的token使用量
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
# 更新特定模型的token使用量
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||
base_tokens = 0
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||
|
||||
# 基础token:每MB文件大小消耗10个token
|
||||
base_tokens = int((file_size / (1024 * 1024)) * 10)
|
||||
|
||||
if file_type == "image":
|
||||
img = Image.open(file_path)
|
||||
width, height = img.size
|
||||
pixel_count = width * height
|
||||
|
||||
# 图片token:每100个像素额外消耗5个token
|
||||
image_tokens = int((pixel_count / 10000) * 5)
|
||||
|
||||
base_tokens += image_tokens
|
||||
|
||||
elif file_type == "video":
|
||||
vr = VideoReader(file_path)
|
||||
fps = vr.get_avg_fps()
|
||||
frame_count = len(vr)
|
||||
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||
|
||||
pixel_count = width * height * frame_count
|
||||
duration = frame_count / fps # 视频时长(秒)
|
||||
|
||||
# 视频token:每100万像素每秒额外消耗1个token
|
||||
video_tokens = int((pixel_count / 10000) * (duration / 60))
|
||||
|
||||
base_tokens += video_tokens
|
||||
|
||||
return max(1, base_tokens) # 确保至少返回1个token
|
||||
except Exception as e:
|
||||
print(f"计算token时出错: {str(e)}")
|
||||
return 1 # 出错时返回默认值1
|
||||
|
||||
|
||||
|
||||
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
|
||||
if task_type not in KAFKA_TOPICS:
|
||||
raise HTTPException(status_code=400, detail="不支持的任务类型")
|
||||
|
||||
content = await file.read()
|
||||
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 计算 token
|
||||
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||
tokens_required = calculate_tokens(file_path, file_type)
|
||||
|
||||
# 检查并更新 token 使用量
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
await update_token_usage(api_key, tokens_required, task_type)
|
||||
|
||||
# 创建任务记录
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"filename": new_filename,
|
||||
"file_type": file_type,
|
||||
"task_type": task_type,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 存储任务信息到Redis
|
||||
redis_client.set(f"task:{task_id}", json.dumps(task_data))
|
||||
logging.info(f"任务信息已存储到Redis: {task_id}")
|
||||
|
||||
# 发送任务到对应的Kafka主题
|
||||
kafka_topic = KAFKA_TOPICS[task_type]
|
||||
producer.send(kafka_topic, task_data)
|
||||
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
|
||||
|
||||
response_data = {
|
||||
"message": "文件已上传并排队等待处理",
|
||||
"task_id": task_id,
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{task_type}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
}
|
||||
logging.info(f"上传文件完成: {task_id}")
|
||||
return JSONResponse(content=response_data)
|
||||
|
||||
# 为每个任务类型创建单独的端点
|
||||
@v1_app.post("/pose")
|
||||
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
logging.info(f"收到 /pose端点的请求")
|
||||
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/cpm")
|
||||
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/qwenvl")
|
||||
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/yolo")
|
||||
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/fall")
|
||||
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/face")
|
||||
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
logging.info(f"收到 /face 端点的请求")
|
||||
return await upload_file(file, task_type="face", api_key_info=api_key_info)
|
||||
|
||||
@v1_app.post("/mediapipe")
|
||||
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
|
||||
|
||||
|
||||
@v1_app.get("/result/{task_id}")
|
||||
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
# 从 REDIS_DB (15) 获取任务状态
|
||||
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||
|
||||
if task_info['status'] != 'completed':
|
||||
return {"status": task_info['status'], "message": "Task is not completed yet"}
|
||||
|
||||
result_type = task_info['result_type']
|
||||
result_key = task_info['result_key']
|
||||
|
||||
# 根据任务类型选择相应的 Redis 客户端
|
||||
redis_client_map = {
|
||||
'pose': redis_pose_client,
|
||||
'cpm': redis_cpm_client,
|
||||
'yolo': redis_yolo_client,
|
||||
'face': redis_face_client,
|
||||
'fall': redis_fall_client,
|
||||
'mediapipe': redis_mediapipe_client,
|
||||
'qwenvl': redis_qwenvl_client
|
||||
}
|
||||
|
||||
result_redis = redis_client_map.get(result_type)
|
||||
if not result_redis:
|
||||
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||
|
||||
result = result_redis.hgetall(result_key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||
|
||||
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||
|
||||
# 将 result 字段解析为 JSON(如果存在)
|
||||
if 'result' in result:
|
||||
result['result'] = json.loads(result['result'])
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"result_type": result_type,
|
||||
"result": result
|
||||
}
|
||||
|
||||
@v1_app.get("/annotated/{task_id}")
|
||||
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
|
||||
api_key_info = await verify_api_key(api_key)
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||
|
||||
# 从 REDIS_DB (15) 获取任务信息
|
||||
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||
if not task_info:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||
|
||||
if task_info['status'] != 'completed':
|
||||
raise HTTPException(status_code=400, detail="Task is not completed yet")
|
||||
|
||||
result_type = task_info.get('result_type')
|
||||
result_key = task_info.get('result_key')
|
||||
|
||||
if not result_key:
|
||||
raise HTTPException(status_code=404, detail="Result key not found")
|
||||
|
||||
if result_type in ['cpm', 'qwenvl']:
|
||||
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
|
||||
|
||||
# 根据任务类型选择相应的 Redis 客户端
|
||||
redis_client_map = {
|
||||
'pose': redis_pose_client,
|
||||
'yolo': redis_yolo_client,
|
||||
'face': redis_face_client,
|
||||
'fall': redis_fall_client,
|
||||
'mediapipe': redis_mediapipe_client
|
||||
}
|
||||
|
||||
result_redis = redis_client_map.get(result_type)
|
||||
if not result_redis:
|
||||
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||
|
||||
result = result_redis.hgetall(result_key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||
|
||||
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||
|
||||
result_file = result.get('result_file')
|
||||
if not result_file:
|
||||
raise HTTPException(status_code=404, detail="Result file not found")
|
||||
|
||||
file_path = os.path.join(RESULT_DIR, result_file)
|
||||
if not os.path.exists(file_path):
|
||||
raise HTTPException(status_code=404, detail="Result image file not found")
|
||||
|
||||
return FileResponse(file_path, media_type="image/png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8005)
|
||||
@@ -1,88 +0,0 @@
|
||||
from flask import Flask, request, send_file, jsonify
|
||||
import ChatTTS
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from flask_cors import CORS
|
||||
import os
|
||||
import pickle
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
chat_tts = ChatTTS.Chat()
|
||||
chat_tts.load(compile=False)
|
||||
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
SPEAKER_EMBEDDING_FILE = 'cutegirl_speaker_embedding.pkl'
|
||||
AUDIO_DIR = '/www/wwwroot/chat.obscura.work/audio_files'
|
||||
|
||||
with open(SPEAKER_EMBEDDING_FILE, 'rb') as f:
|
||||
FIXED_SPEAKER = pickle.load(f)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def generate_audio(text):
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||||
spk_emb=FIXED_SPEAKER,
|
||||
temperature=0.3,
|
||||
top_P=0.6,
|
||||
top_K=20,
|
||||
)
|
||||
|
||||
wavs = chat_tts.infer(text, params_infer_code=params_infer_code)
|
||||
audio_data = wavs[0]
|
||||
|
||||
if not np.issubdtype(audio_data.dtype, np.floating):
|
||||
audio_data = audio_data.astype(np.float32)
|
||||
|
||||
if np.max(np.abs(audio_data)) > 1:
|
||||
audio_data = audio_data / np.max(np.abs(audio_data))
|
||||
|
||||
return audio_data
|
||||
|
||||
def get_audio_filename(text):
|
||||
return hashlib.md5(text.encode()).hexdigest() + '.wav'
|
||||
|
||||
@app.route('/synthesize', methods=['POST', 'OPTIONS'])
|
||||
async def synthesize():
|
||||
if request.method == 'OPTIONS':
|
||||
return '', 204
|
||||
|
||||
data = request.json
|
||||
texts = data.get('texts')
|
||||
if not texts:
|
||||
return jsonify({"error": "No texts provided"}), 400
|
||||
|
||||
audio_urls = []
|
||||
|
||||
for text in texts:
|
||||
filename = get_audio_filename(text)
|
||||
filepath = os.path.join(AUDIO_DIR, filename)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
audio_urls.append(f"/audio_files/{filename}")
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
audio_data = await loop.run_in_executor(executor, generate_audio, text)
|
||||
sf.write(filepath, audio_data, SAMPLE_RATE)
|
||||
audio_urls.append(f"/audio_files/{filename}")
|
||||
|
||||
return jsonify({"audio_urls": audio_urls})
|
||||
|
||||
@app.route('/audio_files/<filename>', methods=['GET', 'OPTIONS'])
|
||||
def get_audio(filename):
|
||||
if request.method == 'OPTIONS':
|
||||
return '', 204
|
||||
try:
|
||||
return send_file(os.path.join(AUDIO_DIR, filename), mimetype='audio/wav')
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 404
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=5002)
|
||||
@@ -1,92 +0,0 @@
|
||||
from flask import Flask, request, send_file, jsonify
|
||||
import ChatTTS
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from flask_cors import CORS
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import pickle
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# 初始化 ChatTTS
|
||||
chat_tts = ChatTTS.Chat()
|
||||
chat_tts.load(compile=False)
|
||||
|
||||
# 定义采样率
|
||||
SAMPLE_RATE = 24000
|
||||
|
||||
# # 生成一个固定的说话人嵌入
|
||||
# FIXED_SPEAKER = chat_tts.sample_random_speaker()
|
||||
|
||||
# 文件名用于保存和加载说话人嵌入
|
||||
SPEAKER_EMBEDDING_FILE = 'two_speaker_embedding.pkl'
|
||||
|
||||
def get_or_create_fixed_speaker():
|
||||
try:
|
||||
if os.path.exists(SPEAKER_EMBEDDING_FILE):
|
||||
with open(SPEAKER_EMBEDDING_FILE, 'rb') as f:
|
||||
fixed_speaker = pickle.load(f)
|
||||
else:
|
||||
fixed_speaker = chat_tts.sample_random_speaker()
|
||||
with open(SPEAKER_EMBEDDING_FILE, 'wb') as f:
|
||||
pickle.dump(fixed_speaker, f)
|
||||
except (EOFError, pickle.UnpicklingError):
|
||||
print("Warning: Unable to load speaker embedding. Creating a new one.")
|
||||
fixed_speaker = chat_tts.sample_random_speaker()
|
||||
with open(SPEAKER_EMBEDDING_FILE, 'wb') as f:
|
||||
pickle.dump(fixed_speaker, f)
|
||||
return fixed_speaker
|
||||
|
||||
|
||||
# 获取或创建固定的说话人嵌入
|
||||
FIXED_SPEAKER = get_or_create_fixed_speaker()
|
||||
|
||||
@app.route('/synthesize', methods=['POST'])
|
||||
def synthesize():
|
||||
data = request.json
|
||||
text = data.get('text')
|
||||
if not text:
|
||||
return jsonify({"error": "No text provided"}), 400
|
||||
|
||||
temp_file = None
|
||||
try:
|
||||
params_infer_code = ChatTTS.Chat.InferCodeParams(
|
||||
spk_emb=FIXED_SPEAKER,
|
||||
temperature=0.3,
|
||||
top_P=0.7,
|
||||
top_K=20,
|
||||
)
|
||||
|
||||
params_refine_text = ChatTTS.Chat.RefineTextParams(
|
||||
prompt='[oral_2][laugh_0][break_6]',
|
||||
)
|
||||
|
||||
wavs = chat_tts.infer(text, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
|
||||
|
||||
audio_data = wavs[0]
|
||||
|
||||
if not np.issubdtype(audio_data.dtype, np.floating):
|
||||
audio_data = audio_data.astype(np.float32)
|
||||
|
||||
if np.max(np.abs(audio_data)) > 1:
|
||||
audio_data = audio_data / np.max(np.abs(audio_data))
|
||||
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
sf.write(temp_file.name, audio_data, SAMPLE_RATE)
|
||||
|
||||
return send_file(temp_file.name, mimetype='audio/wav')
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
finally:
|
||||
if temp_file and os.path.exists(temp_file.name):
|
||||
os.unlink(temp_file.name)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=5002)
|
||||
@@ -1,49 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text, target_language, output_path, output_filename):
|
||||
# Read reference text
|
||||
with open(ref_text_path, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
# Change model weights
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
|
||||
# Synthesize audio
|
||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, output_filename)
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
|
||||
def main():
|
||||
GPT_model_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"
|
||||
SoVITS_model_path = "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"
|
||||
ref_audio_path = "/home/zydi//worker_chat/kafka/sample/woman.wav"
|
||||
ref_text_path = "/home/zydi//worker_chat/kafka/sample/woman.txt"
|
||||
ref_language = "中文"
|
||||
target_text = """我们开发了"病人实时健康监测系统"和"AI辅助诊断系统",这些系统显著提高了医疗诊断的效率和准确性。obscura形成了全面的医疗智能解决方案"""
|
||||
|
||||
target_language = "多语种混合"
|
||||
output_path = "/home/zydi//worker_chat/kafka"
|
||||
output_filename = "output.wav"
|
||||
|
||||
synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text, target_language, output_path, output_filename)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,210 +0,0 @@
|
||||
import json
|
||||
import threading
|
||||
import redis
|
||||
from kafka import KafkaConsumer, KafkaProducer
|
||||
from pymongo import MongoClient,TEXT
|
||||
import requests
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁的方式回答问题。当用户询问数据库相关内容时,你可以访问MongoDB数据库来获取额外的信息。"
|
||||
|
||||
def search_mongodb(mongodb_client, query, mongo_db, search_collection):
|
||||
collection = mongodb_client[mongo_db][search_collection]
|
||||
results = collection.find({"$text": {"$search": query}})
|
||||
|
||||
processed_results = []
|
||||
for result in results:
|
||||
if 'original_answer' in result:
|
||||
processed_results.append(result['original_answer'])
|
||||
else:
|
||||
# 记录这个问题并跳过这个结果
|
||||
print(f"警告: 文档缺少 'original_answer' 字段: {result['_id']}")
|
||||
|
||||
return processed_results
|
||||
|
||||
def get_conversation_history(redis_client, conversation_id, max_history=5):
|
||||
history = redis_client.lrange(f"conversation:{conversation_id}", 0, max_history * 2 - 1)
|
||||
return list(zip(history[::2], history[1::2]))
|
||||
|
||||
def add_to_conversation_history(redis_client, conversation_id, query, answer):
|
||||
redis_client.rpush(f"conversation:{conversation_id}", query, answer)
|
||||
redis_client.expire(f"conversation:{conversation_id}", 3600) # 设置1小时的过期时间
|
||||
|
||||
def generate_answer(query, context, history, mongo_client, mongo_db, search_collection):
|
||||
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
|
||||
for past_query, past_response in history:
|
||||
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
|
||||
full_prompt += f"用户: {query}\n上下文: {context}\n\n"
|
||||
full_prompt += "请根据上下文和历史对话回答用户的问题。如果需要额外信息,你可以使用以下函数查询MongoDB数据库:\n"
|
||||
full_prompt += "search_mongodb(query: str) -> List[str]\n"
|
||||
full_prompt += "该函数会返回与查询相关的内容列表。\n"
|
||||
full_prompt += "助手: "
|
||||
|
||||
def search_mongodb_wrapper(query):
|
||||
return search_mongodb(mongo_client, query, mongo_db, search_collection)
|
||||
|
||||
data = {
|
||||
"model": "llama3.1",
|
||||
"prompt": full_prompt,
|
||||
"stream": True,
|
||||
"temperature": 0,
|
||||
"functions": [
|
||||
{
|
||||
"name": "search_mongodb",
|
||||
"description": "Search for information in MongoDB",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
],
|
||||
"function_call": "auto"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
json_data = json.loads(line)
|
||||
if 'response' in json_data:
|
||||
text_output += json_data['response']
|
||||
elif 'function_call' in json_data:
|
||||
function_call = json.loads(json_data['function_call'])
|
||||
if function_call['name'] == 'search_mongodb':
|
||||
search_results = search_mongodb_wrapper(function_call['arguments']['query'])
|
||||
text_output += f"根据MongoDB查询结果:{', '.join(search_results)}\n"
|
||||
|
||||
return text_output
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating answer: {str(e)}")
|
||||
return "抱歉,生成回答时出现错误。"
|
||||
|
||||
def process_message(message, kafka_config, mongodb_client, redis_client, mongo_db, search_collection):
|
||||
try:
|
||||
# 检查 message.value 是否已经是字典
|
||||
if isinstance(message.value, dict):
|
||||
message_data = message.value
|
||||
else:
|
||||
# 如果不是字典,尝试解析为 JSON
|
||||
message_data = json.loads(message.value)
|
||||
|
||||
query = message_data.get('text', '')
|
||||
conversation_id = message_data.get('conversation_id', 'default')
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# 如果解析失败或 message.value 不是预期的格式,假设整个消息就是查询文本
|
||||
query = str(message.value)
|
||||
conversation_id = 'default'
|
||||
|
||||
# 获取对话历史
|
||||
history = get_conversation_history(redis_client, conversation_id)
|
||||
|
||||
# 在 MongoDB 中搜索相关信息
|
||||
search_results = search_mongodb(mongodb_client, query, mongo_db, search_collection)
|
||||
context = " ".join(search_results)
|
||||
|
||||
# 使用 llama3.1:8b 模型生成答案
|
||||
answer = generate_answer(query, context, history, mongodb_client, mongo_db, search_collection)
|
||||
|
||||
# 将对话添加到 Redis 历史记录
|
||||
add_to_conversation_history(redis_client, conversation_id, query, answer)
|
||||
|
||||
# 将答案发送到 Kafka 的 voice-output 主题
|
||||
producer = KafkaProducer(bootstrap_servers=[kafka_config['bootstrap_servers']],
|
||||
value_serializer=lambda x: json.dumps(x).encode('utf-8'))
|
||||
producer.send(kafka_config['voice_output_topic'], {'answer': answer, 'conversation_id': conversation_id})
|
||||
|
||||
print(f"Processed message: {query}")
|
||||
print(f"Generated answer: {answer}")
|
||||
print(f"Sent to voice-output topic: {{'answer': '{answer}', 'conversation_id': '{conversation_id}'}}")
|
||||
def consumer_thread(kafka_config, mongodb_config, redis_config, consumer_group, thread_id):
|
||||
consumer = KafkaConsumer(
|
||||
kafka_config['text_input_topic'],
|
||||
bootstrap_servers=[kafka_config['bootstrap_servers']],
|
||||
group_id=consumer_group,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True
|
||||
)
|
||||
|
||||
mongodb_client = MongoClient(mongodb_config['uri'])
|
||||
redis_client = redis.Redis(
|
||||
host=redis_config['host'],
|
||||
port=redis_config['port'],
|
||||
db=redis_config['db'],
|
||||
password=redis_config['password']
|
||||
)
|
||||
|
||||
print(f"Consumer thread {thread_id} started, listening to {kafka_config['text_input_topic']} topic, consumer group: {consumer_group}")
|
||||
|
||||
for message in consumer:
|
||||
print(f"Thread {thread_id} received message in partition {message.partition}, offset: {message.offset}")
|
||||
process_message(message, kafka_config, mongodb_client, redis_client,
|
||||
mongodb_config['db_name'], mongodb_config['search_collection'])
|
||||
|
||||
def main(kafka_config, mongodb_config, redis_config):
|
||||
threads = []
|
||||
consumer_group = f"{kafka_config['consumer_group_prefix']}_single"
|
||||
|
||||
for i in range(kafka_config['num_threads']):
|
||||
thread = threading.Thread(
|
||||
target=consumer_thread,
|
||||
args=(kafka_config, mongodb_config, redis_config, consumer_group, i)
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
print(f"Started consumer thread {i}, consumer group: {consumer_group}")
|
||||
|
||||
# Wait for all threads to complete (in reality, they will run indefinitely)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Kafka configuration
|
||||
KAFKA_BOOTSTRAP_SERVERS = '222.186.136.78:9092'
|
||||
KAFKA_INPUT_TOPIC = 'text-input'
|
||||
KAFKA_OUTPUT_TOPIC = 'voice-output'
|
||||
KAFKA_CONSUMER_GROUP_PREFIX = 'text_group'
|
||||
KAFKA_NUM_THREADS = 3 # 您可以根据需要调整线程数
|
||||
|
||||
# MongoDB configuration
|
||||
MONGO_URI = 'mongodb://minio_mongo:BCd4npzKBnwmCRdh@222.186.136.78:27017/?authSource=minio_mongo'
|
||||
MONGO_DB = 'minio_mongo'
|
||||
MONGO_SEARCH_COLLECTION = 'cpm'
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = '222.186.136.78'
|
||||
REDIS_PORT = 6379
|
||||
REDIS_DB = 0
|
||||
REDIS_PASSWORD = 'Obscura@2024' # 添加Redis密码
|
||||
|
||||
kafka_config = {
|
||||
'bootstrap_servers': KAFKA_BOOTSTRAP_SERVERS,
|
||||
'text_input_topic': KAFKA_INPUT_TOPIC,
|
||||
'voice_output_topic': KAFKA_OUTPUT_TOPIC,
|
||||
'consumer_group_prefix': KAFKA_CONSUMER_GROUP_PREFIX,
|
||||
'num_threads': KAFKA_NUM_THREADS
|
||||
}
|
||||
|
||||
mongodb_config = {
|
||||
'uri': MONGO_URI,
|
||||
'db_name': MONGO_DB,
|
||||
'search_collection': MONGO_SEARCH_COLLECTION
|
||||
}
|
||||
|
||||
redis_config = {
|
||||
'host': REDIS_HOST,
|
||||
'port': REDIS_PORT,
|
||||
'db': REDIS_DB,
|
||||
'password': REDIS_PASSWORD # 添加Redis密码
|
||||
}
|
||||
|
||||
main(kafka_config, mongodb_config, redis_config)
|
||||
@@ -1,169 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from kafka import KafkaConsumer, KafkaProducer, TopicPartition
|
||||
import whisper
|
||||
from pydub import AudioSegment
|
||||
import io
|
||||
import tempfile
|
||||
from minio import Minio
|
||||
import threading
|
||||
import requests
|
||||
|
||||
def get_audio_from_minio(minio_client, bucket, object_name):
|
||||
try:
|
||||
response = minio_client.get_object(bucket, object_name)
|
||||
return response.read()
|
||||
except Exception as e:
|
||||
print(f"从 MinIO 获取音频时出错: {str(e)}")
|
||||
print(f"Bucket: {bucket}, Object: {object_name}")
|
||||
return None
|
||||
|
||||
def process_audio(model, audio_data, file_extension):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}') as temp_audio_file:
|
||||
temp_audio_file.write(audio_data)
|
||||
temp_audio_file.flush()
|
||||
|
||||
if file_extension.lower() != 'wav':
|
||||
audio = AudioSegment.from_file(temp_audio_file.name, format=file_extension)
|
||||
wav_path = temp_audio_file.name + '.wav'
|
||||
audio.export(wav_path, format="wav")
|
||||
else:
|
||||
wav_path = temp_audio_file.name
|
||||
|
||||
result = model.transcribe(wav_path)
|
||||
|
||||
os.unlink(temp_audio_file.name)
|
||||
if file_extension.lower() != 'wav':
|
||||
os.unlink(wav_path)
|
||||
|
||||
return json.dumps({"text": result["text"]})
|
||||
|
||||
def consumer_thread(kafka_config, minio_config, model, partition):
|
||||
client_id = f'voice_{partition}'
|
||||
group_id = f"{kafka_config['consumer_group']}_{partition}"
|
||||
|
||||
consumer = KafkaConsumer(
|
||||
bootstrap_servers=[kafka_config['bootstrap_servers']],
|
||||
group_id=group_id,
|
||||
client_id=client_id,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
enable_auto_commit=True,
|
||||
auto_commit_interval_ms=5000
|
||||
)
|
||||
|
||||
# 手动分配分区
|
||||
topic_partition = TopicPartition(kafka_config['voice_input_topic'], partition)
|
||||
consumer.assign([topic_partition])
|
||||
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[kafka_config['bootstrap_servers']],
|
||||
value_serializer=lambda x: json.dumps(x).encode('utf-8')
|
||||
)
|
||||
|
||||
minio_client = Minio(
|
||||
minio_config['endpoint'],
|
||||
access_key=minio_config['access_key'],
|
||||
secret_key=minio_config['secret_key'],
|
||||
secure=minio_config['secure']
|
||||
)
|
||||
|
||||
print(f"消费者 {client_id} 开始监听 {kafka_config['voice_input_topic']} 主题的分区 {partition}...")
|
||||
|
||||
|
||||
for message in consumer:
|
||||
print(f"消费者 {client_id} 从分区 {partition} 收到新的音频消息")
|
||||
event_info = message.value
|
||||
print(f"事件信息: {event_info}")
|
||||
|
||||
# 从S3事件中提取音频文件路径
|
||||
audio_path = event_info.get('Key')
|
||||
if not audio_path:
|
||||
# 如果在顶层没有找到Key,尝试从Records中获取
|
||||
records = event_info.get('Records', [])
|
||||
if records:
|
||||
audio_path = records[0].get('s3', {}).get('object', {}).get('key')
|
||||
|
||||
if not audio_path:
|
||||
print(f"消费者 {client_id} 无法从事件中获取音频路径")
|
||||
continue
|
||||
|
||||
# 移除可能的 'audio/' 前缀
|
||||
if audio_path.startswith('audio/'):
|
||||
audio_path = audio_path[6:]
|
||||
|
||||
file_extension = audio_path.split('.')[-1] if '.' in audio_path else 'wav'
|
||||
|
||||
print(f"尝试从MinIO获取音频: {audio_path}")
|
||||
audio_data = get_audio_from_minio(minio_client, minio_config['bucket'], audio_path)
|
||||
|
||||
if audio_data:
|
||||
try:
|
||||
transcribed_result = process_audio(model, audio_data, file_extension)
|
||||
transcribed_data = json.loads(transcribed_result)
|
||||
transcribed_text = transcribed_data["text"]
|
||||
print(f"消费者 {client_id} 识别结果: {transcribed_text}")
|
||||
|
||||
audio_info = {
|
||||
'file_name': audio_path,
|
||||
'file_extension': file_extension,
|
||||
'size': len(audio_data)
|
||||
}
|
||||
|
||||
# 发送到下一个Kafka主题
|
||||
producer.send(kafka_config['text_input_topic'], value=json.dumps({
|
||||
'text': transcribed_text,
|
||||
'audio_info': audio_info,
|
||||
'consumer_id': client_id,
|
||||
'partition': partition
|
||||
}))
|
||||
print(f"消费者 {client_id} 已发送识别结果到 {kafka_config['text_input_topic']} 主题")
|
||||
|
||||
# # 发送结果到PHP服务
|
||||
# send_to_php(transcribed_text, audio_info)
|
||||
|
||||
except Exception as e:
|
||||
print(f"消费者 {client_id} 处理音频时出错: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
else:
|
||||
print(f"消费者 {client_id} 无法获取音频数据")
|
||||
|
||||
consumer.close()
|
||||
|
||||
def main(kafka_config, minio_config, whisper_config):
|
||||
model = whisper.load_model(whisper_config['model_name'])
|
||||
|
||||
threads = []
|
||||
for partition in range(kafka_config['num_partitions']):
|
||||
thread = threading.Thread(target=consumer_thread, args=(kafka_config, minio_config, model, partition))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Kafka 配置
|
||||
kafka_config = {
|
||||
'bootstrap_servers': '222.186.136.78:9092',
|
||||
'voice_input_topic': 'voice-input',
|
||||
'text_input_topic': 'text-input',
|
||||
'consumer_group': 'voice_group',
|
||||
'num_partitions': 3 # 修改为实际的分区数
|
||||
}
|
||||
|
||||
# MinIO 配置
|
||||
minio_config = {
|
||||
'endpoint': "api.obscura.work",
|
||||
'access_key': "00v3MtLtIAIkR3hkIuYR",
|
||||
'secret_key': "XfDeVe5bJjPU21NEYc023gzJVUTJzQqxsWHqIKMf",
|
||||
'bucket': 'audio',
|
||||
'secure': True
|
||||
}
|
||||
# Whisper 配置
|
||||
whisper_config = {
|
||||
'model_name': 'large-v3' # 可以根据需要选择不同的模型大小
|
||||
}
|
||||
|
||||
main(kafka_config, minio_config, whisper_config)
|
||||
@@ -1,202 +0,0 @@
|
||||
import json
|
||||
import hashlib
|
||||
import io
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from kafka import KafkaConsumer, KafkaProducer
|
||||
from minio import Minio
|
||||
from GPT_SoVITS.inference_webui import get_tts_wav
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
import soundfile as sf
|
||||
import redis
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
# Global variables
|
||||
global_model_config = None
|
||||
|
||||
def initialize_models(model_config):
|
||||
global global_model_config
|
||||
global_model_config = model_config
|
||||
print("Models initialized")
|
||||
|
||||
def generate_content_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
def synthesize(target_text):
|
||||
global global_model_config
|
||||
|
||||
with open(global_model_config['ref_text_path'], 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=global_model_config['ref_audio_path'],
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(global_model_config['ref_language']),
|
||||
text=target_text,
|
||||
text_language=i18n(global_model_config['target_language']),
|
||||
top_p=1, temperature=1
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
print(f"Synthesizing audio for text: {target_text}")
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
return last_sampling_rate, last_audio_data
|
||||
return None, None
|
||||
|
||||
def process_message(message_data, minio_client, minio_config, redis_client):
|
||||
target_text = message_data.get('answer', message_data.get('text', ''))
|
||||
content_hash = generate_content_hash(target_text)
|
||||
|
||||
# Check Redis cache
|
||||
cached_audio_info = redis_client.get(content_hash)
|
||||
if cached_audio_info:
|
||||
print(f"Using Redis cached audio, content hash: {content_hash}")
|
||||
return json.loads(cached_audio_info)
|
||||
|
||||
# If not in Redis, check MinIO
|
||||
bucket_name = minio_config['bucket']
|
||||
object_name = f"{content_hash}.wav"
|
||||
try:
|
||||
minio_client.stat_object(bucket_name, object_name)
|
||||
print(f"Using existing audio from MinIO, content hash: {content_hash}")
|
||||
audio_info = {
|
||||
'id': content_hash,
|
||||
'text': target_text,
|
||||
'content_hash': content_hash,
|
||||
'minio_bucket': bucket_name,
|
||||
'minio_object': object_name,
|
||||
'status': 'completed',
|
||||
}
|
||||
redis_client.set(content_hash, json.dumps(audio_info))
|
||||
return audio_info
|
||||
except:
|
||||
pass # Object doesn't exist in MinIO, continue to synthesis
|
||||
|
||||
# If not found, synthesize new audio
|
||||
try:
|
||||
sampling_rate, audio_data = synthesize(target_text)
|
||||
|
||||
if audio_data is not None:
|
||||
audio_id = content_hash
|
||||
object_name = f"{audio_id}.wav"
|
||||
audio_buffer = io.BytesIO()
|
||||
sf.write(audio_buffer, audio_data, sampling_rate, format='wav')
|
||||
audio_buffer.seek(0)
|
||||
|
||||
minio_client.put_object(
|
||||
bucket_name, object_name, audio_buffer,
|
||||
length=audio_buffer.getbuffer().nbytes
|
||||
)
|
||||
|
||||
etag = minio_client.stat_object(bucket_name, object_name).etag
|
||||
|
||||
audio_info = {
|
||||
'id': audio_id,
|
||||
'text': target_text,
|
||||
'sampling_rate': sampling_rate,
|
||||
'content_hash': content_hash,
|
||||
'minio_bucket': bucket_name,
|
||||
'minio_object': object_name,
|
||||
'etag': etag,
|
||||
'status': 'completed',
|
||||
}
|
||||
|
||||
redis_client.set(content_hash, json.dumps(audio_info))
|
||||
|
||||
return audio_info
|
||||
except Exception as e:
|
||||
print(f"Error processing message: {e}")
|
||||
error_info = {
|
||||
'status': 'failed',
|
||||
'error': str(e),
|
||||
'text': target_text,
|
||||
'content_hash': content_hash,
|
||||
}
|
||||
redis_client.set(content_hash, json.dumps(error_info))
|
||||
return None
|
||||
|
||||
def message_handler(message, minio_client, minio_config, redis_client):
|
||||
print(f"Processing message: {message.value}")
|
||||
message_data = json.loads(message.value)
|
||||
audio_info = process_message(message_data, minio_client, minio_config, redis_client)
|
||||
|
||||
def consumer_thread(consumer_id, kafka_config, minio_config, redis_config):
|
||||
consumer = KafkaConsumer(
|
||||
kafka_config['text_input_topic'],
|
||||
bootstrap_servers=kafka_config['bootstrap_servers'],
|
||||
auto_offset_reset='latest',
|
||||
enable_auto_commit=True,
|
||||
group_id=kafka_config['consumer_group']
|
||||
)
|
||||
|
||||
minio_client = Minio(
|
||||
minio_config['endpoint'],
|
||||
access_key=minio_config['access_key'],
|
||||
secret_key=minio_config['secret_key'],
|
||||
secure=minio_config['secure']
|
||||
)
|
||||
|
||||
redis_client = redis.Redis(
|
||||
host=redis_config['host'],
|
||||
port=redis_config['port'],
|
||||
db=redis_config['db'],
|
||||
password=redis_config['password']
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=kafka_config['threads_per_consumer']) as executor:
|
||||
print(f"Consumer {consumer_id} started running")
|
||||
for message in consumer:
|
||||
executor.submit(message_handler, message, minio_client, minio_config, redis_client)
|
||||
|
||||
def main(kafka_config, minio_config, model_config, redis_config):
|
||||
initialize_models(model_config)
|
||||
|
||||
threads = []
|
||||
for i in range(kafka_config['num_consumers']):
|
||||
t = threading.Thread(target=consumer_thread, args=(i, kafka_config, minio_config, redis_config))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Kafka configuration
|
||||
kafka_config = {
|
||||
'bootstrap_servers': '222.186.136.78:9092',
|
||||
'text_input_topic': 'voice-output',
|
||||
'consumer_group': 'voice_group',
|
||||
'num_consumers': 3,
|
||||
'threads_per_consumer': 4
|
||||
}
|
||||
|
||||
# MinIO configuration
|
||||
minio_config = {
|
||||
'endpoint': "api.obscura.work",
|
||||
'access_key': "00v3MtLtIAIkR3hkIuYR",
|
||||
'secret_key': "XfDeVe5bJjPU21NEYc023gzJVUTJzQqxsWHqIKMf",
|
||||
'bucket': 'tts-audio',
|
||||
'secure': True
|
||||
}
|
||||
|
||||
# Redis configuration
|
||||
redis_config = {
|
||||
'host': '222.186.136.78',
|
||||
'port': 6379,
|
||||
'db': 4,
|
||||
'password': "Obscura@2024"
|
||||
}
|
||||
|
||||
# Model configuration
|
||||
model_config = {
|
||||
'GPT_model_path': "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
|
||||
'SoVITS_model_path': "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
|
||||
'ref_audio_path': "sample/woman.wav",
|
||||
'ref_text_path': "sample/woman.txt",
|
||||
'ref_language': "中文",
|
||||
'target_language': "多语种混合"
|
||||
}
|
||||
|
||||
main(kafka_config, minio_config, model_config, redis_config)
|
||||
@@ -1 +0,0 @@
|
||||
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
@@ -1,406 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 在文件顶部添加这个函数
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
v1_chat_app = FastAPI()
|
||||
app.mount("/v1_chat", v1_chat_app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
|
||||
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
|
||||
# Redis 配置
|
||||
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
|
||||
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
|
||||
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
|
||||
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
|
||||
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
|
||||
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
|
||||
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
|
||||
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
|
||||
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
|
||||
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
|
||||
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
|
||||
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
|
||||
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
|
||||
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
|
||||
OUTPUT_PATH= os.getenv('OUTPUT_PATH')
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
|
||||
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
|
||||
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
|
||||
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
|
||||
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
|
||||
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
|
||||
|
||||
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
|
||||
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
|
||||
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
|
||||
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
|
||||
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
|
||||
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
|
||||
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
|
||||
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
|
||||
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
|
||||
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
|
||||
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
|
||||
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
|
||||
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
|
||||
|
||||
# 创建一个音色到对应 Redis 客户端的映射
|
||||
voice_to_redis = {
|
||||
"default": redis_tts_girl,
|
||||
"girl": redis_tts_girl,
|
||||
"woman": redis_tts_woman,
|
||||
"man": redis_tts_man,
|
||||
"leijun": redis_tts_leijun,
|
||||
"dufu": redis_tts_dufu,
|
||||
"hejiong": redis_tts_hejiong,
|
||||
"mahuateng": redis_tts_mahuateng,
|
||||
"lidan": redis_tts_lidan,
|
||||
"yuhua": redis_tts_yuhua,
|
||||
"liuzhenyun": redis_tts_liuzhenyun,
|
||||
"dabing": redis_tts_dabing,
|
||||
"luoxiang": redis_tts_luoxiang,
|
||||
"xuzhiyuan": redis_tts_xuzhiyuan
|
||||
}
|
||||
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
pipe = redis_api_usage_client.pipeline()
|
||||
|
||||
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||
pipe.hset(redis_key, "last_used_at", current_time)
|
||||
|
||||
model_tokens_field = f"{model_name}_tokens_used"
|
||||
model_last_used_field = f"{model_name}_last_used_at"
|
||||
|
||||
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||
|
||||
pipe.execute()
|
||||
|
||||
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
|
||||
api_key = api_key_info['api_key']
|
||||
usage_key = f"api_key:{api_key}"
|
||||
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||
|
||||
if tokens_used + tokens_required > total_tokens:
|
||||
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||
|
||||
# 更新 token 使用量
|
||||
await update_token_usage(api_key, tokens_required, model_name)
|
||||
|
||||
# 发送任务到Kafka
|
||||
producer.send(kafka_topic, task_data)
|
||||
|
||||
# 获取更新后的 token 使用情况
|
||||
updated_api_key_info = await verify_api_key(api_key)
|
||||
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||
|
||||
return {
|
||||
"message": f"{model_name.upper()}请求已排队等待处理",
|
||||
"tokens_used": tokens_required,
|
||||
"total_tokens_used": new_tokens_used,
|
||||
f"{model_name}_tokens_used": model_tokens_used,
|
||||
"tokens_remaining": total_tokens - new_tokens_used
|
||||
}
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
voice: str = Field(..., description="选择的音色")
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
|
||||
@v1_chat_app.post("/tts")
|
||||
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
text_hash = get_audio_hash(request.text)
|
||||
|
||||
# 验证音色选择
|
||||
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||
if request.voice not in valid_voices:
|
||||
raise HTTPException(status_code=400, detail="无效的音色选择")
|
||||
|
||||
# 如果声音是 'default',则将其视为 'girl'
|
||||
voice = 'girl' if request.voice == 'default' else request.voice
|
||||
|
||||
# 使用对应音色的 Redis 客户端
|
||||
redis_tts = voice_to_redis[request.voice]
|
||||
|
||||
# 检查是否已存在相同内容的音频文件
|
||||
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||
if existing_audio_info:
|
||||
existing_audio_path = json.loads(existing_audio_info)['path']
|
||||
if os.path.exists(existing_audio_path):
|
||||
return {
|
||||
"message": "TTS请求已完成",
|
||||
"task_id": task_id,
|
||||
"status": "completed",
|
||||
"audio_path": existing_audio_path
|
||||
}
|
||||
|
||||
# 如果不存在,创建新的任务
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"text": request.text,
|
||||
"text_hash": text_hash,
|
||||
"voice": request.voice,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 存储任务信息到Redis
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
@v1_chat_app.post("/asr")
|
||||
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
UPLOAD_DIR = "/obscura/task/audio_upload"
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
|
||||
|
||||
with open(file_path, "wb") as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
|
||||
task_data = {
|
||||
'file_path': file_path,
|
||||
'task_id': task_id,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# 存储任务状态,使用一致的键名格式
|
||||
redis_task_client.set(f"task_status:asr:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
@v1_chat_app.post("/chat")
|
||||
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"session_id": request.session_id,
|
||||
"query": request.query,
|
||||
"model": request.model,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 设置任务状态为 "queued"
|
||||
redis_task_client.set(f"chat:{task_id}:status", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
|
||||
@v1_chat_app.get("/chat_result/{task_id}")
|
||||
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"chat:{task_id}:status")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis任务数据库获取聊天结果
|
||||
chat_result = redis_task_client.get(f"chat:{task_id}:result")
|
||||
if chat_result:
|
||||
result = json.loads(chat_result)
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_result/{task_id}")
|
||||
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||
if task_info:
|
||||
task_data = json.loads(task_info)
|
||||
text_hash = task_data['text_hash']
|
||||
voice = task_data['voice']
|
||||
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||
|
||||
audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
return {
|
||||
"status": "completed",
|
||||
"audio_path": audio_path
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/asr_result/{task_id}")
|
||||
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态,使用一致的键名格式
|
||||
task_status = redis_task_client.get(f"task_status:asr:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis ASR结果数据库获取转录结果
|
||||
transcription = redis_asr_client.get(f"asr:{task_id}")
|
||||
if transcription:
|
||||
return {
|
||||
"status": "completed",
|
||||
"transcription": transcription.decode('utf-8')
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_audio/{task_id}")
|
||||
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从任务信息中获取使用的音色
|
||||
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||
if task_info:
|
||||
task_data = json.loads(task_info)
|
||||
voice = task_data.get('voice', 'girl') # 默认使用 'girl'
|
||||
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||
|
||||
# 从对应音色的 Redis 获取音频文件路径
|
||||
audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
if os.path.exists(audio_path):
|
||||
file_name = os.path.basename(audio_path)
|
||||
return FileResponse(audio_path, media_type="audio/wav", filename=file_name)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="音频文件不存在")
|
||||
elif status == "queued" or status == "processing":
|
||||
raise HTTPException(status_code=202, detail="音频文件正在生成中")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="任务处理失败")
|
||||
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
@v1_chat_app.get("/getvoice")
|
||||
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
|
||||
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||
return {"available_voices": valid_voices}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user