Files
api/api_old/before/cpm_key.py
T
2025-01-12 06:15:15 +00:00

518 lines
18 KiB
Python

import os
import json
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import FastAPI, HTTPException, UploadFile, File, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer, KafkaConsumer
from transformers import AutoTokenizer, AutoModel
from decord import VideoReader, cpu
from PIL import Image
from redis import Redis
import io
import re
import torch
import asyncio
from contextlib import asynccontextmanager
import threading
# Root ASGI application; the CPM endpoints are registered on `cpm_app`,
# which is mounted under the /cpm prefix.
app = FastAPI()
cpm_app = FastAPI()
app.mount("/cpm", cpm_app)

# CORS policy registered on the root app: browsers may call the API (with
# credentials) only from the beta front-end origin; any method and request
# header is allowed.
ALLOWED_ORIGINS = ["https://beta.obscura.work"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=ALLOWED_ORIGINS,
)
# Configuration.
# Every deployment-specific setting can be overridden with a CPM_-prefixed
# environment variable. The defaults are the values previously hard-coded
# here, so an existing deployment behaves the same when nothing is set.
# NOTE(security): a Redis password committed to source is a leaked secret —
# set CPM_REDIS_PASSWORD in the environment, then rotate and drop this default.
MODEL_PATH = os.environ.get("CPM_MODEL_PATH", "/home/zydi/worker_sys/OpenBMB/MiniCPM-V-2_6")
KAFKA_BROKER = os.environ.get("CPM_KAFKA_BROKER", "222.186.10.253:9092")
KAFKA_TOPIC = os.environ.get("CPM_KAFKA_TOPIC", "cpm")
KAFKA_GROUP_ID = os.environ.get("CPM_KAFKA_GROUP_ID", "cpm_group")
REDIS_HOST = os.environ.get("CPM_REDIS_HOST", "222.186.10.253")
REDIS_PORT = int(os.environ.get("CPM_REDIS_PORT", "6379"))
REDIS_PASSWORD = os.environ.get("CPM_REDIS_PASSWORD", "Obscura@2024")
# Logical Redis databases: task status/results, API-key records, token usage.
REDIS_DB = int(os.environ.get("CPM_REDIS_DB", "5"))
REDIS_API_DB = int(os.environ.get("CPM_REDIS_API_DB", "12"))
REDIS_API_USAGE_DB = int(os.environ.get("CPM_REDIS_API_USAGE_DB", "13"))
UPLOAD_DIR = os.environ.get("CPM_UPLOAD_DIR", "/www/wwwroot/beta.obscura.work/upload_files/upload")
RESULT_DIR = os.environ.get("CPM_RESULT_DIR", "/www/wwwroot/beta.obscura.work/upload_files/result")
# NOTE(review): not referenced by the visible code — presumably meant for
# cleaning up old uploads; confirm or remove.
MAX_FILE_AGE = timedelta(hours=1)

# Make sure the upload and result directories exist before any request arrives.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# Kafka wiring: `producer` enqueues upload tasks on KAFKA_TOPIC; `consumer` is
# iterated by process_task() and yields message values already parsed from JSON.
# NOTE(review): with enable_auto_commit=True offsets can be committed before a
# task has finished processing, so a crash mid-task may drop it — confirm.
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    group_id=KAFKA_GROUP_ID,
    bootstrap_servers=[KAFKA_BROKER],
    value_deserializer=lambda raw_bytes: json.loads(raw_bytes.decode("utf-8")),
    auto_offset_reset="earliest",
    enable_auto_commit=True,
)
def _connect_redis(db):
    """Return a Redis client for logical database *db* on the configured server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


# One client per logical database, all on the same server.
redis_client = _connect_redis(REDIS_DB)                        # task status / results
redis_api_client = _connect_redis(REDIS_API_DB)                # API-key records
redis_api_usage_client = _connect_redis(REDIS_API_USAGE_DB)    # token usage counters
# API-key authentication: clients send their key in this request header.
API_KEY_NAME = "X-API-Key"
# NOTE(review): `api_key_header` is not used by the dependencies visible here;
# get_api_key reads the header through Header(alias=API_KEY_NAME) instead.
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME)
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
    """FastAPI dependency returning the raw API key from the request header.

    Raises:
        HTTPException: 400 when the header is missing.
    """
    if api_key is not None:
        return api_key
    raise HTTPException(status_code=400, detail="API密钥缺失")
def _parse_expiry(raw_expires_at):
    """Parse a stored ``expires_at`` string into a timezone-aware datetime.

    Returns None when the value is missing or not ISO-8601, so the caller can
    reject the key instead of raising. Naive timestamps are interpreted as UTC.
    NOTE(review): confirm keys are written with UTC expiry times.
    """
    if not raw_expires_at:
        return None
    try:
        expires_at = datetime.fromisoformat(raw_expires_at)
    except ValueError:
        return None
    if expires_at.tzinfo is None:
        # Comparing a naive value with the aware "now" below would raise TypeError.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Look up *api_key* and return its merged record and usage, or None if unusable.

    The key is rejected (None) when it has no hash in the API-key database, when
    ``is_active`` is not ``'1'``, when ``expires_at`` is missing or unparseable,
    or when it has expired. On success, returns the key hash merged with its
    usage hash (usage fields win on name clashes), all decoded to ``str``.
    """
    redis_key = f"api_key:{api_key}"
    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        return None
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}
    if api_key_info.get('is_active') != '1':
        return None
    expires_at = _parse_expiry(api_key_info.get('expires_at'))
    if expires_at is None or datetime.now(timezone.utc) > expires_at:
        return None
    raw_usage = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_usage.items()}
    return {
        **api_key_info,
        **usage_info
    }
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add *new_tokens_used* to the key's overall and per-model token counters.

    Each counter is paired with a "last used" timestamp (UTC ISO-8601). All four
    writes go through a single Redis pipeline on the usage database.
    """
    usage_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()
    # (counter field, last-used field): the overall totals first, then this model's.
    field_pairs = (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    )
    pipeline = redis_api_usage_client.pipeline()
    for counter_field, timestamp_field in field_pairs:
        pipeline.hincrby(usage_key, counter_field, new_tokens_used)
        pipeline.hset(usage_key, timestamp_field, now_iso)
    pipeline.execute()
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Estimate the token cost of analysing the file at *file_path*.

    Cost is 10 tokens per MiB of file size, plus:
      * image: 5 tokens per 10,000 pixels;
      * video: (width * height * frame_count / 10,000) * (duration_seconds / 60).
    Any other *file_type* is priced on size alone. The result is at least 1.
    This is best effort: if the file cannot be read or decoded, the error is
    printed and 1 is returned.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes
        tokens = int((file_size / (1024 * 1024)) * 10)
        if file_type == "image":
            # The context manager releases the file handle Image.open() keeps open.
            with Image.open(file_path) as img:
                width, height = img.size
            tokens += int((width * height / 10000) * 5)
        elif file_type == "video":
            vr = VideoReader(file_path)
            fps = vr.get_avg_fps()
            frame_count = len(vr)
            # Decode the first frame once; its shape is indexed as (height, width, ...).
            frame_shape = vr[0].shape
            height, width = frame_shape[0], frame_shape[1]
            pixel_count = width * height * frame_count
            duration = frame_count / fps  # seconds
            tokens += int((pixel_count / 10000) * (duration / 60))
        return max(1, tokens)  # always charge at least one token
    except Exception as e:
        print(f"计算token时出错: {str(e)}")
        return 1
# Pin this process to CUDA device 0, then load MiniCPM-V in fp16 eval mode
# together with its tokenizer.
# NOTE: trust_remote_code=True runs Python code shipped with the checkpoint in
# MODEL_PATH; only point it at trusted checkpoints.
torch.cuda.set_device(0)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).half().cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
class MediaAnalysisSystem:
    """Describe uploaded videos and images with MiniCPM-V and extract coarse facts.

    Media bytes are decoded into PIL images and the model is asked for a
    description. ``extract_info`` then pulls environment, people count and
    keyword lists out of the answer by substring and regex matching.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:0")
        self.model = self.model.to(self.device)
        # Upper bound on the number of frames sampled from one video.
        self.MAX_NUM_FRAMES = 16

    def encode_video(self, video_data):
        """Decode video bytes into PIL frames sampled about once per second.

        At most ``MAX_NUM_FRAMES`` frames are returned; longer videos are thinned
        to evenly spaced samples.

        Raises:
            ValueError: if the video contains no frames.
        """
        def uniform_sample(seq, n):
            # n evenly spaced picks, each taken from the middle of its bucket.
            gap = len(seq) / n
            return [seq[int(i * gap + gap / 2)] for i in range(n)]

        vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
        if len(vr) == 0:
            raise ValueError("Video contains no frames")
        # Frame step giving ~1 sample per second. Clamp to 1 so an average fps
        # below 0.5 cannot round to a zero step (range() rejects step 0).
        sample_fps = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_video(self, video_data, object_name):
        """Describe a video with the model.

        Returns a dict with the raw answer, the fields from ``extract_info``, and
        the number of frames sent to the model.

        Raises:
            ValueError: if *video_data* is empty or contains no frames.
        """
        if not video_data:
            raise ValueError(f"Empty video data for {object_name}")
        print(f"Processing video: {object_name}, data size: {len(video_data)} bytes")
        frames = self.encode_video(video_data)
        question = "Describe the video in as much detail as possible in Chinese, including the setting, clear number of people, and changes in behavior."
        msgs = [
            {'role': 'user', 'content': frames + [question]},
        ]
        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        answer = self.model.chat(
            image=frames,  # the frames are also embedded in msgs[0]['content']
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )
        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "num_frames": len(frames),
        }

    def process_image(self, image_data, object_name):
        """Describe an image with the model.

        Returns a dict with the raw answer, the fields from ``extract_info``, and
        a local-time processing timestamp.

        Raises:
            ValueError: if *image_data* is empty.
        """
        if not image_data:
            raise ValueError(f"Empty image data for {object_name}")
        # MiniCPM-V's usage examples feed RGB images; converting also handles
        # palette, RGBA and greyscale uploads.
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        question = "描述这张图片,包括场景、人物数量和行为等细节。"
        msgs = [
            {'role': 'user', 'content': [image] + [question]},
        ]
        params = {
            "use_image_id": False,
            "max_slice_nums": 1
        }
        answer = self.model.chat(
            image=None,
            msgs=msgs,
            tokenizer=self.tokenizer,
            max_length=512,
            temperature=0.7,
            top_p=0.9,
            **params
        )
        return {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    @staticmethod
    def extract_time_from_filename(object_name):
        """Return (start, start + 10 s) parsed from a ``YYYYmmdd_HHMMSS...`` file name.

        Falls back to (now, now + 10 s) when the name does not follow that
        pattern. This includes names without an underscore.
        """
        filename = os.path.basename(object_name)
        parts = filename.split('_')
        try:
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            now = datetime.now()
            return now, now + timedelta(seconds=10)
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Extract coarse structured facts from a model answer by keyword matching.

        Returns a dict with:
          * environment: first keyword from a fixed list found in the answer, or None;
          * num_people: count from the first people pattern that matches. It is the
            parsed digits or Chinese numeral, 0 when the matched phrase carries no
            count, and None when no pattern matches;
          * actions / interactions / objects / furniture: keywords found in the
            answer, in list order, each reported at most once.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": [],
        }

        def matched(keywords):
            # Skip empty keywords (an empty string is a substring of every answer)
            # and report each keyword at most once, preserving list order.
            return [kw for kw in dict.fromkeys(keywords) if kw and kw in answer]

        environments = ["办公室", "室内", "室外", "会议室"]
        lowered = answer.lower()
        for env in environments:
            if env in lowered:
                info["environment"] = env
                break

        people_patterns = [
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|二|两|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            # NOTE(review): this pattern's leading count group looks lost. As written
            # it matches any 名/位 (e.g. in 位于) and yields num_people = 0.
            r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)'
        ]
        # Chinese numerals captured by the patterns above; any other non-digit
        # capture (几个, 男, 员工, ...) maps to 0.
        num_word_to_digit = {
            '一个': 1, '一': 1, '二': 2, '两': 2, '三': 3, '四': 4, '五': 5,
            '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
        }
        for pattern in people_patterns:
            match = re.search(pattern, answer)
            if match:
                count_word = match.group(1)
                if count_word.isdigit():
                    info["num_people"] = int(count_word)
                else:
                    info["num_people"] = num_word_to_digit.get(count_word, 0)
                break

        # NOTE(review): the upstream action and furniture lists also held empty-string
        # entries (likely single-character words lost in an encoding mishap). They are
        # omitted because "" matches every answer; restore the intended words.
        info["actions"] = matched(["摔倒", "跳舞", "转身", "倒下", "躺下", "跳跃", "说话"])
        info["interactions"] = matched(["互动", "交流", "身体语言", "交谈", "讨论", "开会"])
        info["objects"] = matched(["水瓶", "办公用品", "文件", "电脑"])
        info["furniture"] = matched(["椅子", "桌子", "咖啡桌", "文件柜", "沙发"])
        return info
# Module-wide analyser instance used by the Kafka worker (process_task).
media_analysis_system = MediaAnalysisSystem(model=model, tokenizer=tokenizer)
async def process_file(file: UploadFile, file_type: str, api_key: str):
    """Save an upload, charge tokens for it, and enqueue it for analysis.

    The file is stored in UPLOAD_DIR under a fresh ``cpm_<uuid><ext>`` name and
    priced with calculate_tokens(). The price is charged to the key's usage
    counters. A "queued" status is published in Redis, and the task is sent to
    Kafka.

    Returns:
        JSONResponse with the stored filename and the token accounting.

    Raises:
        HTTPException: 403 when the key's remaining balance cannot cover the
            upload. The saved file is removed in that case.
    """
    content = await file.read()
    # UploadFile.filename can be None; only the extension of the client name is kept.
    original_extension = os.path.splitext(file.filename or "")[1]
    filename = f"cpm_{uuid.uuid4()}{original_extension}"
    file_path = os.path.join(UPLOAD_DIR, filename)
    with open(file_path, "wb") as f:
        f.write(content)

    # Price the saved file (size plus image/video dimensions).
    tokens_required = calculate_tokens(file_path, file_type)

    # Check the balance before charging.
    # NOTE(review): check-then-increment is not atomic, so concurrent uploads for
    # the same key can overspend.
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    if tokens_used + tokens_required > total_tokens:
        # Nothing will ever process this upload, so don't leave it on disk.
        try:
            os.remove(file_path)
        except OSError as cleanup_error:
            print(f"Failed to remove rejected upload {file_path}: {cleanup_error}")
        raise HTTPException(status_code=403, detail="Token 余额不足")

    model_name = "MiniCPM-V-2_6"
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish "queued" BEFORE enqueueing. The consumer writes "processing" and then
    # "completed"/"failed" to this same key, so a later "queued" write could
    # overwrite a finished result.
    redis_key = f"{file_type}_result:{filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))
    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": filename,
        "type": file_type
    }).encode('utf-8'))

    # Read the updated counters straight from the usage hash. Going through
    # verify_api_key() could return None (e.g. the key expired meanwhile) and crash.
    new_tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    model_tokens_used = int(redis_api_usage_client.hget(usage_key, f"{model_name}_tokens_used") or 0)
    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
async def _verified_process(file: UploadFile, file_type: str, api_key: str):
    """Shared endpoint body: authorise *api_key*, then hand the upload to process_file.

    Raises:
        HTTPException: 403 for an invalid, inactive or expired key. HTTPExceptions
            raised by process_file (e.g. 403 for an insufficient token balance)
            propagate unchanged.

    Any other error is returned as a 500 JSON response ``{"error": ...}``.
    """
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")
    try:
        return await process_file(file, file_type, api_key)
    except HTTPException:
        # Deliberate HTTP errors must keep their status code instead of becoming 500s.
        raise
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@cpm_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Queue an image or video; the type is inferred from the upload's Content-Type."""
    # Anything not declared as image/* (including a missing Content-Type) is treated as video.
    content_type = file.content_type or ""
    file_type = "image" if content_type.startswith("image") else "video"
    return await _verified_process(file, file_type, api_key)


@cpm_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Queue an uploaded file for video analysis."""
    return await _verified_process(file, "video", api_key)


@cpm_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Queue an uploaded file for image analysis."""
    return await _verified_process(file, "image", api_key)
def process_task():
    """Kafka worker loop: analyse each queued upload and publish its result.

    Iterates ``consumer`` indefinitely. Each task message has the form
    ``{"filename": ..., "type": "video" | "image"}``. For each one, the loop:
      * sets the Redis key ``<type>_result:<filename>`` to "processing";
      * runs the matching analyser on the file in UPLOAD_DIR;
      * writes the result to ``RESULT_DIR/<filename>.json``;
      * stores ``{"status": "completed", "result": ...}``.
    A per-task error is printed. When the task identity is known, it is also
    recorded as ``{"status": "failed", "error": ...}``. The loop then continues.
    """
    for message in consumer:
        # Reset per message so a failure is never attributed to the previous task.
        filename = None
        file_type = None
        try:
            # The consumer's deserializer already yields dicts; raw bytes are a fallback.
            if isinstance(message.value, dict):
                task = message.value
            else:
                task = json.loads(message.value.decode('utf-8'))
            filename = task['filename']
            file_type = task['type']
            file_path = os.path.join(UPLOAD_DIR, filename)
            redis_key = f"{file_type}_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))
            with open(file_path, 'rb') as f:
                file_data = f.read()
            if file_type == "video":
                result = media_analysis_system.process_video(file_data, filename)
            elif file_type == "image":
                result = media_analysis_system.process_image(file_data, filename)
            else:
                # Fail explicitly; otherwise `result` could be unbound or left over
                # from an earlier message.
                raise ValueError(f"Unsupported task type: {file_type!r}")
            # Persist the result as a JSON file...
            result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
            with open(result_file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)
            # ...and publish it under the key that get_result() reads.
            redis_client.set(redis_key, json.dumps({
                "status": "completed",
                "result": result
            }))
        except Exception as e:
            print(f"Error processing task: {str(e)}")
            if filename is not None and file_type is not None:
                try:
                    redis_client.set(
                        f"{file_type}_result:{filename}",
                        json.dumps({"status": "failed", "error": str(e)}),
                    )
                except Exception as redis_error:
                    # Keep the worker thread alive even if Redis is unreachable.
                    print(f"Failed to record task failure: {redis_error}")
            else:
                print("Error occurred before task details were extracted")
@cpm_app.get("/result/{filename}")
async def get_result(filename: str):
    """Return the analysis status, or the final stored payload, for an uploaded file.

    Checks ``video_result:<filename>`` and then ``image_result:<filename>`` in
    Redis. While a task is queued or processing, a short status message is
    returned. Otherwise the stored JSON (completed result or failure) is
    returned as-is.

    Raises:
        HTTPException: 404 when neither key exists.
    """
    pending_messages = {
        "queued": "Your request is in the queue and will be processed soon.",
        "processing": "Your request is being processed.",
    }
    for file_type in ("video", "image"):
        stored = redis_client.get(f"{file_type}_result:{filename}")
        if not stored:
            continue
        payload = json.loads(stored)
        status = payload.get("status")
        if status in ("queued", "processing"):
            return {"status": status, "message": pending_messages[status]}
        return payload
    raise HTTPException(status_code=404, detail="Result not found")
async def listen_redis_changes():
    """Print keyspace notifications for ``*_result:*`` keys in the task-status DB.

    Keyspace notifications must be enabled on the Redis server
    (``notify-keyspace-events``). NOTE(review): nothing in this file enables
    them. The blocking pub/sub loop runs in the default thread-pool executor, so
    awaiting this coroutine does not stall the event loop.
    """
    pubsub = redis_client.pubsub()
    # Subscribe on the database redis_client actually uses, not a hard-coded index.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:*_result:*')

    def _drain():
        for message in pubsub.listen():
            if message['type'] == 'pmessage':
                channel = message['channel'].decode('utf-8')
                # The channel is "__keyspace@<db>__:<key>". Keep the whole key,
                # which itself contains ':'.
                key = channel.partition(':')[2]
                print(f"Key changed: {key}")

    try:
        await asyncio.get_running_loop().run_in_executor(None, _drain)
    finally:
        pubsub.close()
if __name__ == "__main__":
    # Run the Kafka worker on a daemon thread so it exits together with the server.
    consumer_thread = threading.Thread(daemon=True, target=process_task)
    consumer_thread.start()

    import uvicorn

    # Serve the root app (the CPM API is mounted under /cpm) on all interfaces.
    uvicorn.run(app, port=7000, host="0.0.0.0")