# File info (source-viewer residue): 518 lines, 18 KiB, Python
import asyncio
import io
import json
import os
import re
import threading
import traceback
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone

# Third-party imports are kept in their original relative order to avoid any
# import-order side effects between these native libraries.
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer, KafkaConsumer
from transformers import AutoTokenizer, AutoModel
from decord import VideoReader, cpu
from PIL import Image
from redis import Redis
import torch


app = FastAPI()

# Sub-application that hosts every MiniCPM endpoint; served under /cpm.
cpm_app = FastAPI()
app.mount("/cpm", cpm_app)


# CORS: only the beta front-end origin may call the API; for that origin,
# credentials, every HTTP method and every request header are allowed.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

_CORS_OPTIONS = {
    "allow_origins": ALLOWED_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


# Configuration.
# Connection settings, credentials and paths can be overridden with CPM_*
# environment variables; the literals are the previous hard-coded values and
# remain the defaults so existing deployments behave exactly as before.
MODEL_PATH = os.environ.get("CPM_MODEL_PATH", "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6")

KAFKA_BROKER = os.environ.get("CPM_KAFKA_BROKER", "222.186.10.253:9092")
KAFKA_TOPIC = "cpm"
KAFKA_GROUP_ID = "cpm_group"

REDIS_HOST = os.environ.get("CPM_REDIS_HOST", "222.186.10.253")
REDIS_PORT = int(os.environ.get("CPM_REDIS_PORT", "6379"))
# TODO(security): this password is committed to source control. Provide
# CPM_REDIS_PASSWORD in the environment and remove the fallback literal.
REDIS_PASSWORD = os.environ.get("CPM_REDIS_PASSWORD", "Obscura@2024")
REDIS_DB = 5            # holds the "<type>_result:<filename>" job status keys
REDIS_API_DB = 12       # holds "api_key:<key>" metadata hashes
REDIS_API_USAGE_DB = 13 # holds "api_key:<key>" token-usage hashes

UPLOAD_DIR = os.environ.get("CPM_UPLOAD_DIR", "/www/wwwroot/beta.obscura.work/upload_files/upload")
RESULT_DIR = os.environ.get("CPM_RESULT_DIR", "/www/wwwroot/beta.obscura.work/upload_files/result")
# NOTE(review): not referenced anywhere else in this file; presumably meant for
# cleaning up old uploads/results — confirm or remove.
MAX_FILE_AGE = timedelta(hours=1)

# Make sure the upload and result directories exist.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)


# Kafka wiring. The producer publishes upload jobs as raw bytes (no
# serializer configured); the consumer decodes them back from UTF-8 JSON for
# the worker loop in process_task. Offsets are auto-committed.
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])


def _decode_json_message(raw_value):
    """Deserialize a Kafka message value from UTF-8 encoded JSON bytes."""
    return json.loads(raw_value.decode('utf-8'))


consumer = KafkaConsumer(
    KAFKA_TOPIC,
    group_id=KAFKA_GROUP_ID,
    bootstrap_servers=[KAFKA_BROKER],
    value_deserializer=_decode_json_message,
    enable_auto_commit=True,
    auto_offset_reset='earliest',
)


# Redis clients: one connection object per logical database on the same server.
def _redis_for_db(db_index):
    """Build a Redis client for ``db_index`` using the shared host/port/password."""
    return Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        db=db_index,
    )


redis_client = _redis_for_db(REDIS_DB)                      # job status / results
redis_api_client = _redis_for_db(REDIS_API_DB)              # API-key metadata
redis_api_usage_client = _redis_for_db(REDIS_API_USAGE_DB)  # token usage counters


# API-key authentication: clients send their key in the X-API-Key header.
API_KEY_NAME = "X-API-Key"
# NOTE(review): not used by any endpoint in this file; get_api_key reads the
# header itself through Header(alias=...).
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)


async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
    """FastAPI dependency that returns the raw X-API-Key header value.

    Raises:
        HTTPException: 400 when the header is absent.
    """
    if api_key is not None:
        return api_key
    raise HTTPException(status_code=400, detail="API密钥缺失")


def _parse_expires_at(value):
    """Parse an ISO-8601 expiry string into an aware datetime, or None if unusable."""
    if not value:
        return None
    # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on.
    if value.endswith('Z'):
        value = value[:-1] + '+00:00'
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        # NOTE(review): naive timestamps are assumed to be UTC — confirm with
        # whatever writes expires_at. Comparing naive vs aware would raise.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate an API key against Redis and return its merged records.

    Reads the hash ``api_key:<api_key>`` from the API-key DB and from the usage DB.

    Returns:
        A dict of str -> str containing the key's metadata merged with its usage
        counters (usage fields win on name clashes), or None when the key does
        not exist, ``is_active`` is not '1', ``expires_at`` is missing or
        unparseable, or the key has expired.

    NOTE(review): the Redis client is synchronous, so these calls block the
    event loop while they run.
    """
    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        return None
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}

    if api_key_info.get('is_active') != '1':
        return None

    # Fail closed: a key without a valid expiry is treated as invalid instead of
    # crashing the request with a TypeError.
    expires_at = _parse_expires_at(api_key_info.get('expires_at'))
    if expires_at is None or datetime.now(timezone.utc) > expires_at:
        return None

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    return {
        **api_key_info,
        **usage_info
    }


async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add ``new_tokens_used`` to an API key's usage counters.

    Increments both the overall ``tokens_used`` field and the per-model
    ``<model_name>_tokens_used`` field of the hash ``api_key:<api_key>`` in the
    usage DB, and stamps the matching ``*last_used_at`` fields with the current
    UTC time (ISO-8601). All four commands are sent in one Redis pipeline.
    """
    now_iso = datetime.now(timezone.utc).isoformat()
    usage_key = f"api_key:{api_key}"

    # (counter field, timestamp field) pairs: overall first, then per model.
    field_pairs = (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    )

    pipe = redis_api_usage_client.pipeline()
    for counter_field, timestamp_field in field_pairs:
        pipe.hincrby(usage_key, counter_field, new_tokens_used)
        pipe.hset(usage_key, timestamp_field, now_iso)
    pipe.execute()


def calculate_tokens(file_path: str, file_type: str) -> int:
    """Price an uploaded file in tokens.

    Pricing, as implemented (each component truncated with int()):
      * every file: 10 tokens per MiB of file size;
      * "image": plus 5 tokens per 10,000 pixels (width * height);
      * "video": plus (width * height * frame_count / 10,000) * (duration / 60 s).
    Any other ``file_type`` is priced by size only.

    Args:
        file_path: path of the saved upload.
        file_type: "image", "video", or anything else.

    Returns:
        The price, never less than 1. If measuring the file fails (missing
        file, undecodable media, zero fps, ...) the error is printed and 1 is
        returned.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes

        # Base price: 10 tokens per MiB.
        base_tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            # The context manager closes the file handle PIL keeps open after open().
            with Image.open(file_path) as img:
                width, height = img.size
            pixel_count = width * height

            # 5 tokens per 10,000 pixels. (An earlier comment said "per 100
            # pixels", but the formula has always divided by 10,000.)
            image_tokens = int((pixel_count / 10000) * 5)
            base_tokens += image_tokens

        elif file_type == "video":
            vr = VideoReader(file_path)
            fps = vr.get_avg_fps()
            frame_count = len(vr)
            # Every vr[i] access decodes a frame, so read the first one only once.
            height, width = vr[0].shape[:2]

            pixel_count = width * height * frame_count
            duration = frame_count / fps  # seconds

            # NOTE(review): an earlier comment described "1 token per megapixel
            # per second", but this formula uses total pixels over ALL frames
            # times minutes, so it grows roughly quadratically with duration.
            # The formula is left unchanged; confirm the intended pricing.
            video_tokens = int((pixel_count / 10000) * (duration / 60))
            base_tokens += video_tokens

        return max(1, base_tokens)  # never charge less than one token
    except Exception as e:
        # Best effort: a file that cannot be measured costs the minimum price.
        print(f"计算token时出错: {str(e)}")
        return 1


# Select GPU 0 as this process's current CUDA device.
torch.cuda.set_device(0)

# Load MiniCPM-V (with the repository's remote code) in fp16 on the GPU, in eval mode.
model = (
    AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True)
    .half()
    .cuda()
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)


class MediaAnalysisSystem:
    """Describe images/videos with MiniCPM-V and extract keywords from the answer.

    Args:
        model: the loaded MiniCPM-V model; it is moved onto ``cuda:0`` here.
        tokenizer: the matching tokenizer, passed through to ``model.chat``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0")
        self.model = self.model.to(self.device)
        # Upper bound on the number of frames sent to the model per video.
        self.MAX_NUM_FRAMES = 16

    def encode_video(self, video_data):
        """Decode video bytes into at most MAX_NUM_FRAMES PIL images.

        Frames are sampled about once per second (every round(avg_fps) frames);
        if that yields more than MAX_NUM_FRAMES, they are thinned to
        MAX_NUM_FRAMES evenly spaced ones.
        """
        def uniform_sample(seq, n):
            # Take the element at the centre of each of n equal-width buckets.
            gap = len(seq) / n
            return [seq[int(i * gap + gap / 2)] for i in range(n)]

        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        # Clamp the step to 1: an average fps below 0.5 would otherwise round
        # to 0 and make range() raise ValueError.
        sample_step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_step))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_video(self, video_data, object_name):
        """Ask the model to describe a video.

        Args:
            video_data: raw bytes of the video file.
            object_name: name used in log and error messages.

        Returns:
            dict with the model's raw answer, the extract_info() result and the
            number of frames sent.

        Raises:
            ValueError: if ``video_data`` is empty.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        # The frames are already embedded in msgs, so pass image=None exactly as
        # process_image does. MiniCPM-V's chat() only uses `image` when the
        # first message's content is a plain string.
        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )
        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Ask the model to describe an image.

        Args:
            image_data: raw bytes of the image file.
            object_name: identifier of the upload (currently unused here).

        Returns:
            dict with the model's raw answer, the extract_info() result and a
            local-time timestamp string.
        """
        # Normalise to RGB so palette/RGBA/greyscale files are accepted, as in
        # the MiniCPM-V usage examples.
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        question = "描述这张图片,包括场景、人物数量和行为等细节。"
        msgs = [
            {'role': 'user', 'content': [image] + [question]},
        ]

        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }

        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )

        extracted_info = self.extract_info(answer)

        return {
            "original_answer": answer,
            "extracted_info": extracted_info,
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a ``YYYYmmdd_HHMMSS[...]`` file name into (start, start + 10 s).

        Falls back to (now, now + 10 s) when the name does not match that
        pattern, including names without an underscore.
        """
        filename = os.path.basename(object_name)
        try:
            parts = filename.split('_')
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (ValueError, IndexError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            start_time = datetime.now()
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Extract keyword facts from a Chinese model answer by substring matching.

        Returns:
            dict with:
              environment: first of the known environment words found, else None;
              num_people: count from the first matching people pattern; 0 when a
                  people word matched but no count could be read; None when
                  nothing matched;
              actions / interactions / objects / furniture: the known keywords
                  that occur in the answer, in list order.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": []
        }

        environments = ["办公室", "室内", "室外", "会议室"]
        for env in environments:
            if env in answer.lower():
                info["environment"] = env
                break

        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            # "两" is the usual counting form of 2 ("两个人"), so it is matched too.
            r'(一|二|两|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        # Chinese numerals / counters -> count; any other captured word maps to 0.
        num_word_to_digit = {
            '一': 1, '一个': 1, '二': 2, '两': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10
        }
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if match:
                count_word = match.group(1)
                if count_word.isdigit():
                    info["num_people"] = int(count_word)
                else:
                    info["num_people"] = num_word_to_digit.get(count_word, 0)
                break

        # Each keyword is listed once, so it is reported at most once.
        actions = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "跳跃", "跳", "躺", "睡", "说话"]
        info["actions"] = [action for action in actions if action in answer]

        interactions = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
        info["interactions"] = [item for item in interactions if item in answer]

        objects = ["水瓶", "办公用品", "文件", "电脑"]
        info["objects"] = [obj for obj in objects if obj in answer]

        furniture = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]
        info["furniture"] = [item for item in furniture if item in answer]

        return info


# Module-wide analysis engine, used by the Kafka worker loop (process_task).
media_analysis_system = MediaAnalysisSystem(tokenizer=tokenizer, model=model)


async def process_file(file: UploadFile, file_type: str, api_key: str):
    """Save an upload, charge tokens for it and queue it for analysis.

    The file is stored in UPLOAD_DIR as ``cpm_<uuid><original extension>`` and
    priced with calculate_tokens. If the key's balance suffices, the usage is
    recorded, ``<file_type>_result:<filename>`` is set to "queued" and a job is
    published to Kafka.

    Args:
        file: the uploaded file.
        file_type: "image" or "video"; selects pricing and the worker path.
        api_key: the (already verified) API key to charge.

    Returns:
        JSONResponse with the stored filename and token accounting.

    Raises:
        HTTPException: 403 when tokens_used + price would exceed total_tokens.
    """
    content = await file.read()
    # Only the client's extension is kept; the file name itself is generated
    # here. UploadFile.filename may be None.
    original_extension = os.path.splitext(file.filename or "")[1]

    filename = f"cpm_{uuid.uuid4()}{original_extension}"
    file_path = os.path.join(UPLOAD_DIR, filename)
    with open(file_path, "wb") as f:
        f.write(content)

    tokens_required = calculate_tokens(file_path, file_type)

    # Check the balance.
    # NOTE(review): check-then-increment is not atomic, so concurrent uploads
    # for the same key can overdraw the balance.
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        # This upload will never be processed; don't leave it on disk.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise HTTPException(status_code=403, detail="Token 余额不足")

    model_name = "MiniCPM-V-2_6"
    await update_token_usage(api_key, tokens_required, model_name)

    # Write "queued" BEFORE publishing: the worker may start on the job at once,
    # and a late "queued" write would overwrite its processing/completed status.
    redis_key = f"{file_type}_result:{filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": filename,
        "type": file_type
    }).encode('utf-8'))

    # Read the updated counters directly from the usage DB. Re-running
    # verify_api_key here could return None (e.g. key just expired) and crash.
    new_tokens_raw, model_tokens_raw = redis_api_usage_client.hmget(
        usage_key, "tokens_used", f"{model_name}_tokens_used"
    )
    new_tokens_used = int(new_tokens_raw or 0)
    model_tokens_used = int(model_tokens_raw or 0)

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })


@cpm_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Upload an image or video and queue it for analysis.

    The type is taken from the Content-Type header: ``image*`` is treated as an
    image, anything else (including a missing header) as a video.

    Raises:
        HTTPException: 403 for an invalid key or insufficient tokens.
    Any other failure is returned as a 500 JSON ``{"error": ...}``.
    """
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    try:
        content_type = file.content_type or ""
        file_type = "image" if content_type.startswith("image") else "video"
        return await process_file(file, file_type, api_key)
    except HTTPException:
        # Keep intended HTTP errors (e.g. 403 insufficient tokens) instead of
        # turning them into the generic 500 below.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@cpm_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Upload a file that is always treated as a video and queue it.

    Raises:
        HTTPException: 403 for an invalid key or insufficient tokens.
    Any other failure is returned as a 500 JSON ``{"error": ...}``.
    """
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    try:
        return await process_file(file, "video", api_key)
    except HTTPException:
        # Keep intended HTTP errors instead of converting them to a 500.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@cpm_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Upload a file that is always treated as an image and queue it.

    Raises:
        HTTPException: 403 for an invalid key or insufficient tokens.
    Any other failure is returned as a 500 JSON ``{"error": ...}``.
    """
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    try:
        return await process_file(file, "image", api_key)
    except HTTPException:
        # Keep intended HTTP errors instead of converting them to a 500.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


def process_task():
    """Kafka worker loop: analyse each queued upload and store the result.

    For each message ``{"filename": ..., "type": "image"|"video"}`` it:
      1. sets ``<type>_result:<filename>`` to "processing" in Redis,
      2. reads the file from UPLOAD_DIR and runs the matching
         MediaAnalysisSystem method,
      3. writes the result to ``RESULT_DIR/<filename>.json``,
      4. stores ``{"status": "completed", "result": ...}`` in Redis.
    Any failure is printed with its traceback and, when the task's filename
    and type are known, stored as ``{"status": "failed", "error": ...}``.

    Loops forever over the consumer; started in a daemon thread from __main__.
    NOTE(review): an exception raised by the consumer iterator itself (e.g. the
    JSON value_deserializer on a malformed message) is not caught here and
    would end this loop.
    """
    for message in consumer:
        # Reset per message; otherwise a failure on this message could be
        # recorded against the previous message's file.
        filename = None
        file_type = None
        try:
            if isinstance(message.value, dict):
                task = message.value
            else:
                task = json.loads(message.value.decode('utf-8'))

            filename = task['filename']
            file_type = task['type']

            # Uploads are always stored directly in UPLOAD_DIR; refuse names
            # that would resolve elsewhere.
            if os.path.basename(filename) != filename:
                raise ValueError(f"Invalid task filename: {filename!r}")
            file_path = os.path.join(UPLOAD_DIR, filename)

            redis_key = f"{file_type}_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))

            with open(file_path, 'rb') as f:
                file_data = f.read()

            if file_type == "video":
                result = media_analysis_system.process_video(file_data, filename)
            elif file_type == "image":
                result = media_analysis_system.process_image(file_data, filename)
            else:
                # Without this, `result` would be unbound or, worse, still hold
                # the previous task's result.
                raise ValueError(f"Unsupported task type: {file_type!r}")

            # Save the result as a JSON file.
            result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
            with open(result_file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            # Publish the result through Redis.
            redis_client.set(redis_key, json.dumps({
                "status": "completed",
                "result": result
            }))

        except Exception as e:
            print(f"Error processing task: {str(e)}")
            traceback.print_exc()
            if filename is not None and file_type is not None:
                redis_key = f"{file_type}_result:{filename}"
                redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
            else:
                print("Error occurred before task details were extracted")


@cpm_app.get("/result/{filename}")
async def get_result(filename: str):
    """Return the analysis status or result stored for ``filename``.

    Checks ``video_result:<filename>`` first, then ``image_result:<filename>``.
    "queued" and "processing" entries are answered with a short status message;
    any other stored payload (completed/failed) is returned as-is.

    Raises:
        HTTPException: 404 when neither key exists.
    """
    pending_replies = (
        ("queued", "Your request is in the queue and will be processed soon."),
        ("processing", "Your request is being processed."),
    )
    for kind in ("video", "image"):
        stored = redis_client.get(f"{kind}_result:{filename}")
        if not stored:
            continue
        payload = json.loads(stored)
        status = payload.get("status")
        for pending_status, reply in pending_replies:
            if status == pending_status:
                return {"status": pending_status, "message": reply}
        return payload

    raise HTTPException(status_code=404, detail="Result not found")


async def listen_redis_changes():
    """Print Redis keyspace notifications for ``*_result:*`` keys.

    Subscribes to keyspace events of the results DB (REDIS_DB). Redis only
    publishes these when the server's ``notify-keyspace-events`` setting enables
    keyspace events; otherwise this coroutine waits without output.

    The Redis client is synchronous, so each poll runs in the default executor
    to keep the event loop responsive (the previous blocking ``listen()`` loop
    would have frozen it). Runs until cancelled.
    """
    pubsub = redis_client.pubsub()
    # Use the configured DB index instead of a hard-coded 5.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:*_result:*')
    loop = asyncio.get_running_loop()

    def _poll():
        # Wait up to 1 s for the next pattern message; returns None on timeout.
        return pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)

    try:
        while True:
            message = await loop.run_in_executor(None, _poll)
            if message and message['type'] == 'pmessage':
                # Channel is "__keyspace@<db>__:<key>"; keep the full key.
                key = message['channel'].decode('utf-8').split(':', 1)[1]
                print(f"Key changed: {key}")
    finally:
        pubsub.close()


if __name__ == "__main__":
    # Run the Kafka worker loop in a daemon thread so it exits with the server.
    # NOTE(review): the worker only starts when this file is executed directly;
    # launching via `uvicorn module:app` serves the API without a worker.
    worker = threading.Thread(target=process_task, daemon=True)
    worker.start()

    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7000)