369 lines
13 KiB
Python
369 lines
13 KiB
Python
import os
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from kafka import KafkaProducer, KafkaConsumer
|
|
from transformers import AutoTokenizer, AutoModel
|
|
from decord import VideoReader, cpu
|
|
from PIL import Image
|
|
from redis import Redis
|
|
import io
|
|
import re
|
|
import torch
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
import threading
|
|
|
|
# Root ASGI application. All CPM endpoints are registered on ``cpm_app``,
# which is served under the "/cpm" path prefix of ``app``.
app = FastAPI()
cpm_app = FastAPI()
app.mount("/cpm", cpm_app)

# CORS: allow credentialed cross-origin requests with any method and any
# header, but only from the beta front-end origin.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=ALLOWED_ORIGINS,
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
# Configuration. Every value can be overridden by an environment variable of
# the same name. The literals below are the previous hard-coded values, kept
# only as fallbacks so existing deployments continue to work unchanged.
# NOTE(review): the Redis password and the public broker/Redis address should
# not live in source control. Set REDIS_PASSWORD (etc.) in the environment and
# drop these fallbacks once deployments are updated.
MODEL_PATH = os.getenv("MODEL_PATH", "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6")
KAFKA_BROKER = os.getenv("KAFKA_BROKER", "222.186.10.253:9092")
KAFKA_TOPIC = os.getenv("KAFKA_TOPIC", "cpm")
KAFKA_GROUP_ID = os.getenv("KAFKA_GROUP_ID", "cpm_group")

REDIS_HOST = os.getenv("REDIS_HOST", "222.186.10.253")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "Obscura@2024")
REDIS_DB = int(os.getenv("REDIS_DB", "5"))

# Uploaded media and per-file JSON results are stored here.
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "/www/wwwroot/beta.obscura.work/upload_files/upload")
RESULT_DIR = os.getenv("RESULT_DIR", "/www/wwwroot/beta.obscura.work/upload_files/result")
# NOTE(review): MAX_FILE_AGE is not referenced anywhere in the visible code;
# presumably intended for cleaning up old uploads/results. Confirm or remove.
MAX_FILE_AGE = timedelta(hours=1)

# Make sure both directories exist before any request writes into them.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
|
|
|
|
# Kafka: the upload endpoints publish analysis tasks through ``producer``.
# ``consumer`` is drained by process_task on a background thread.
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])


def _deserialize_task(raw):
    """Decode a Kafka record value as UTF-8 JSON, returning None if impossible.

    kafka-python applies the value deserializer while the consumer is being
    iterated. An exception raised here would therefore escape
    ``for message in consumer`` and end the worker thread. Empty (None) or
    malformed payloads are logged and handed on as None instead.
    """
    if raw is None:
        return None
    try:
        return json.loads(raw.decode('utf-8'))
    except ValueError as exc:  # covers json.JSONDecodeError and UnicodeDecodeError
        print(f"Discarding undecodable Kafka message: {exc}")
        return None


consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=_deserialize_task,
)
|
|
|
|
# Shared Redis connection. Job status and results are kept under keys of the
# form "<type>_result:<filename>".
redis_client = Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD,
)
|
|
|
|
# Pin this process to the first CUDA device. Then load MiniCPM-V as an fp16
# model in eval (inference) mode, together with its tokenizer.
# trust_remote_code lets transformers run the code shipped with the model repo.
torch.cuda.set_device(0)

model = (
    AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
|
|
|
class MediaAnalysisSystem:
    """Describe videos and images with MiniCPM-V and extract keywords from the answer.

    ``process_video`` and ``process_image`` take raw file bytes and ask the
    model for a description. Each returns the answer together with the keyword
    summary produced by ``extract_info``.
    """

    # Keyword tables for extract_info. Matching is plain substring search, so a
    # word and a shorter word it contains (e.g. "摔倒" and "摔") can both be
    # reported.
    _ENVIRONMENTS = ("办公室", "室内", "室外", "会议室")
    _ACTIONS = (
        "坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下",
        "躺下", "跳跃", "跳", "躺", "睡", "说话",
    )
    _INTERACTIONS = ("互动", "交流", "身体语言", "交谈", "讨论", "开会")
    _OBJECTS = ("水瓶", "办公用品", "文件", "电脑")
    _FURNITURE = ("椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发")

    # Head-count patterns, tried in order; the first one that matches decides
    # num_people. Only the first three can yield a non-zero count. A match on
    # any later pattern (people mentioned without a number) gives 0.
    _PEOPLE_PATTERNS = tuple(re.compile(p) for p in (
        r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        # A run of Chinese numerals, so that 两 and compounds like 十二 are read
        # as a whole instead of only their last character.
        r'([一二两三四五六七八九十]+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
        r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
        r'(男|女)(性|生|士)',
        r'(成年|未成年|青少年|老年)\s*(人|群体)',
        r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
        r'(群众|民众|大众|公众)',
        r'(男女|老少|老幼|大人|小孩)',
    ))

    # Values of single Chinese digits; 十 (ten) is handled by the parser.
    _CN_DIGITS = {
        "一": 1, "二": 2, "两": 2, "三": 3, "四": 4,
        "五": 5, "六": 6, "七": 7, "八": 8, "九": 9,
    }
    _CN_NUMERAL_CHARS = frozenset("一二两三四五六七八九十")

    def __init__(self, model, tokenizer):
        """Wrap a loaded MiniCPM-V model and its tokenizer.

        The model is moved to cuda:0 (a no-op if it is already there).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0")
        self.model = self.model.to(self.device)
        # Upper bound on how many frames of one video are sent to the model.
        self.MAX_NUM_FRAMES = 16

    def encode_video(self, video_data):
        """Decode video bytes into at most MAX_NUM_FRAMES PIL images.

        Frames are taken about once per second of video. When that yields more
        than MAX_NUM_FRAMES, they are thinned evenly across the clip.
        """
        def uniform_sample(indices, n):
            # Pick n evenly spaced items, each from the middle of its bucket.
            gap = len(indices) / n
            return [indices[int(i * gap + gap / 2)] for i in range(n)]

        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        # Step of ~avg-FPS frames gives ~1 frame per second. Clamp to >= 1:
        # when the FPS rounds to 0, range() would otherwise get a step of 0
        # and raise.
        sample_step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_video(self, video_data, object_name):
        """Ask the model to describe a video.

        Args:
            video_data: Raw bytes of the video file.
            object_name: Name used in log and error messages.

        Returns:
            dict with "original_answer" (model text), "extracted_info"
            (see extract_info) and "num_frames" (frames sent to the model).

        Raises:
            ValueError: if video_data is empty.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
        # The sampled frames go inside the message content, followed by the question.
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        # NOTE(review): MiniCPM-V's remote chat() documents max_new_tokens as
        # its output-length limit and may silently ignore max_length. Verify
        # before relying on the 512 cap.
        answer = self.model.chat(
            image=frames,  # pass the frames directly
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )
        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Ask the model to describe an image.

        Args:
            image_data: Raw bytes of the image file.
            object_name: Not used; kept for symmetry with process_video.

        Returns:
            dict with "original_answer", "extracted_info" and "timestamp"
            (local time, "%Y-%m-%d %H:%M:%S").
        """
        # Normalise to 3-channel RGB. Palette/RGBA/greyscale uploads (e.g. PNGs)
        # would otherwise reach the model's image processor in another mode.
        # The MiniCPM-V usage examples convert in the same way.
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        question = "描述这张图片,包括场景、人物数量和行为等细节。"
        msgs = [
            {'role': 'user', 'content': [image] + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }

        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )

        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a "YYYYmmdd_HHMMSS..." file name into a 10-second time window.

        Returns:
            (start_time, end_time) with end_time = start_time + 10 s. If the
            name does not follow that pattern (including names without "_"),
            a message is printed and the window starts at the current time.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        try:
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            # Read the clock once so the fallback window is exactly 10 s long.
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @classmethod
    def _chinese_numeral_to_int(cls, text):
        """Convert a short Chinese numeral such as 两, 十二 or 三十五 to an int.

        A run without 十 (e.g. 三四, "three or four") takes its last digit.
        Anything that cannot be interpreted gives 0.
        """
        if not text:
            return 0
        tens, ten, ones = text.rpartition("十")
        if not ten:
            return cls._CN_DIGITS.get(text[-1], 0)
        tens_value = cls._CN_DIGITS.get(tens[-1], 0) if tens else 1
        ones_value = cls._CN_DIGITS.get(ones[0], 0) if ones else 0
        return tens_value * 10 + ones_value

    @classmethod
    def extract_info(cls, answer):
        """Pull coarse keywords out of a model answer.

        Returns a dict with:
            environment: first entry of _ENVIRONMENTS found in the answer, else None.
            num_people: count from the first matching head-count pattern. It is
                0 when people are mentioned without a parseable number, and
                None when no pattern matches.
            actions / interactions / objects / furniture: entries of the
                corresponding keyword table found in the answer, in table order.
        """
        lowered = answer.lower()
        info = {
            "environment": next((env for env in cls._ENVIRONMENTS if env in lowered), None),
            "num_people": None,
            "actions": [word for word in cls._ACTIONS if word in answer],
            "interactions": [word for word in cls._INTERACTIONS if word in answer],
            "objects": [word for word in cls._OBJECTS if word in answer],
            "furniture": [word for word in cls._FURNITURE if word in answer],
        }

        for pattern in cls._PEOPLE_PATTERNS:
            match = pattern.search(answer)
            if not match:
                continue
            token = match.group(1)
            if token.isdigit():
                info["num_people"] = int(token)
            elif token == "一个":
                info["num_people"] = 1
            elif set(token) <= cls._CN_NUMERAL_CHARS:
                info["num_people"] = cls._chinese_numeral_to_int(token)
            else:
                info["num_people"] = 0
            break

        return info
|
|
|
|
# Module-wide analysis service; the Kafka worker (process_task) runs every
# task through this single instance.
media_analysis_system = MediaAnalysisSystem(model=model, tokenizer=tokenizer)
|
|
|
|
async def process_file(file: UploadFile, file_type: str):
    """Save an upload under a random name and enqueue it for analysis.

    The file is written to UPLOAD_DIR as ``cpm_<uuid4><ext>``. Its Redis key
    ``<file_type>_result:<filename>`` is set to "queued", and a
    ``{"filename", "type"}`` task is published to KAFKA_TOPIC for the worker.

    Args:
        file: The uploaded file.
        file_type: "image" or "video"; selects the worker's processing path.

    Returns:
        A dict with a human-readable message and the generated filename, which
        the client later passes to ``/result/{filename}``.

    Raises:
        Errors from reading the upload, writing it to disk, Redis or the Kafka
        producer propagate to the caller. If publishing fails, the Redis status
        is first set to "failed".
    """
    content = await file.read()

    # Keep the client's extension only when it is a short alphanumeric suffix.
    # The rest of the stored name is server-generated. file.filename may be
    # None.
    # NOTE(review): UPLOAD_DIR sits under the web root. Confirm the web server
    # never executes files there (e.g. ".php"), or switch to an allowlist of
    # media extensions.
    original_extension = os.path.splitext(file.filename or "")[1]
    if not re.fullmatch(r"\.[A-Za-z0-9]{1,10}", original_extension):
        original_extension = ""

    filename = f"cpm_{uuid.uuid4()}{original_extension}"
    file_path = os.path.join(UPLOAD_DIR, filename)
    with open(file_path, "wb") as f:
        f.write(content)

    # Record "queued" BEFORE publishing. If the task were published first, the
    # worker could set "processing" (or even "completed") before this write,
    # which would then roll the status back to "queued".
    redis_key = f"{file_type}_result:{filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    try:
        producer.send(KAFKA_TOPIC, json.dumps({
            "filename": filename,
            "type": file_type
        }).encode('utf-8'))
    except Exception as exc:
        # Don't leave a permanent "queued" entry for a task that was never sent.
        redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(exc)}))
        raise

    return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename}
|
|
|
|
@cpm_app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Queue an upload, classifying it as image or video by its MIME type.

    A Content-Type starting with "image" means an image. Everything else,
    including uploads without a Content-Type, is treated as a video. Errors
    come back as a JSON ``{"error": ...}`` body with status 500.
    """
    try:
        # content_type is None when the client omits the part's Content-Type;
        # calling .startswith on it would raise.
        content_type = file.content_type or ""
        file_type = "image" if content_type.startswith("image") else "video"
        return await process_file(file, file_type)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
@cpm_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...)):
    """Queue an upload for video analysis, whatever its MIME type.

    Any error comes back as a JSON ``{"error": ...}`` body with status 500.
    """
    try:
        queued = await process_file(file, "video")
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return queued
|
|
|
|
@cpm_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...)):
    """Queue an upload for image analysis, whatever its MIME type.

    Any error comes back as a JSON ``{"error": ...}`` body with status 500.
    """
    try:
        queued = await process_file(file, "image")
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return queued
|
|
|
|
def process_task():
    """Consume analysis tasks from Kafka forever and publish their results.

    Each task is ``{"filename": ..., "type": "video" | "image"}``. For each one:
      - set the Redis key ``<type>_result:<filename>`` to "processing";
      - run the matching MediaAnalysisSystem method on ``UPLOAD_DIR/<filename>``;
      - write the result to ``RESULT_DIR/<filename>.json``;
      - store ``{"status": "completed", "result": ...}`` in Redis.

    Failures are printed. Once the task's filename and type are known, the
    failure is also recorded as ``{"status": "failed", "error": ...}``. This
    never returns; it is meant to run on a background thread.
    """
    for message in consumer:
        # Reset for every message. These used to leak from the previous
        # iteration, so a malformed message could mark an earlier, unrelated
        # file as failed, or an unknown type could reuse the previous result.
        filename = None
        file_type = None
        try:
            task = message.value
            if isinstance(task, (bytes, bytearray)):
                task = json.loads(task.decode('utf-8'))
            if not isinstance(task, dict):
                raise ValueError(f"Unexpected task payload: {task!r}")

            filename = task['filename']
            file_type = task['type']

            if file_type not in ("video", "image"):
                raise ValueError(f"Unsupported file type: {file_type!r}")
            # Tasks arrive over the network. Accept only bare file names, so a
            # crafted message cannot read outside UPLOAD_DIR or write outside
            # RESULT_DIR.
            if (not isinstance(filename, str) or filename in ("", ".", "..")
                    or os.path.basename(filename) != filename):
                raise ValueError(f"Invalid filename in task: {filename!r}")

            file_path = os.path.join(UPLOAD_DIR, filename)

            redis_key = f"{file_type}_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))

            with open(file_path, 'rb') as f:
                file_data = f.read()

            if file_type == "video":
                result = media_analysis_system.process_video(file_data, filename)
            else:
                result = media_analysis_system.process_image(file_data, filename)

            # Save the result as a JSON file.
            result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
            with open(result_file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            # Publish the result through Redis for /result/{filename}.
            redis_client.set(redis_key, json.dumps({
                "status": "completed",
                "result": result
            }))

        except Exception as e:
            print(f"Error processing task: {str(e)}")
            if filename is not None and file_type is not None:
                try:
                    redis_client.set(
                        f"{file_type}_result:{filename}",
                        json.dumps({"status": "failed", "error": str(e)}),
                    )
                except Exception as status_error:
                    # Keep the consumer loop alive even if Redis is unavailable.
                    print(f"Failed to record failure status: {status_error}")
            else:
                print("Error occurred before task details were extracted")
|
|
|
|
@cpm_app.get("/result/{filename}")
async def get_result(filename: str):
    """Report the processing status or result stored for *filename*.

    The video key is checked first, then the image key. A queued or in-progress
    job gets a short status message. Any other stored record (completed or
    failed) is returned as stored. Raises 404 when neither key exists.
    """
    status_messages = {
        "queued": "Your request is in the queue and will be processed soon.",
        "processing": "Your request is being processed.",
    }
    for kind in ("video", "image"):
        raw = redis_client.get(f"{kind}_result:{filename}")
        if not raw:
            continue
        record = json.loads(raw)
        status = record.get("status")
        if status == "queued" or status == "processing":
            return {"status": status, "message": status_messages[status]}
        return record
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
async def listen_redis_changes():
    """Log Redis keyspace notifications for result keys (``*_result:*``).

    Subscribes to keyspace events of the configured REDIS_DB and prints the
    full key of every change. Each blocking redis-py pub/sub read runs in the
    default executor, so the event loop is never blocked. Runs until
    cancelled, then closes the subscription.

    NOTE(review): Redis only emits keyspace notifications when the server has
    ``notify-keyspace-events`` configured; confirm on the server. This
    coroutine is also not started anywhere in the visible code.
    """
    pubsub = redis_client.pubsub()
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:*_result:*')
    loop = asyncio.get_running_loop()
    try:
        while True:
            # Poll with a short timeout so cancellation is noticed promptly.
            message = await loop.run_in_executor(
                None, lambda: pubsub.get_message(timeout=1.0)
            )
            if message is None or message['type'] != 'pmessage':
                continue
            channel = message['channel']
            if isinstance(channel, bytes):
                channel = channel.decode('utf-8')
            # The channel is "__keyspace@<db>__:<key>". Split once so the key
            # keeps its "<type>_result:" prefix; the old split(':')[-1] kept
            # only the file name.
            key = channel.split(':', 1)[1]
            print(f"Key changed: {key}")
    finally:
        pubsub.close()
|
|
|
|
if __name__ == "__main__":
    # Run the Kafka consumer loop on a daemon thread. The web server keeps
    # the main thread, and the worker ends together with the process.
    worker = threading.Thread(target=process_task, daemon=True)
    worker.start()

    import uvicorn

    # Serve the root app (CPM endpoints are mounted under /cpm) on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7000)