419 lines
15 KiB
Python
419 lines
15 KiB
Python
# main.py
import json
import logging
import mimetypes
import os
import string
import uuid
from datetime import datetime, timedelta, timezone

from decord import VideoReader
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from PIL import Image
from redis import Redis

from config import *


# Root ASGI application. Versioned endpoints are registered on ``v1_app``,
# which is mounted under the ``/v1`` URL prefix.
app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True,
# Starlette's CORSMiddleware appears to echo the caller's Origin back instead
# of "*", so any website can make credentialed cross-origin requests (confirm
# against the installed Starlette version). Authentication here uses the
# Authorization header, so consider an explicit allow-list (a commented-out
# list previously named 'https://beta.obscura.work') once the frontend's
# needs are confirmed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


# Service configuration. The right-hand names are provided by
# ``from config import *``; rebinding them here lists the settings this
# module relies on.
KAFKA_BROKER = KAFKA_BROKER
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = MAIN_REDIS_DB  # database holding the task:{task_id} records
REDIS_API_DB = REDIS_API_DB  # API-key metadata
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB  # API-key token usage counters
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
# NOTE(review): MAX_FILE_AGE is not referenced anywhere else in this file.
MAX_FILE_AGE = timedelta(hours=1)

# Create the upload and result directories up front so request handlers can
# write/read files without checking.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Supported task types; each one is published to the Kafka topic of the same
# name.
KAFKA_TOPICS = {
    task_name: task_name
    for task_name in ('pose', 'mediapipe', 'qwenvl', 'yolo', 'fall', 'face', 'cpm')
}

TASK_TYPES = list(KAFKA_TOPICS)


# Kafka producer used to dispatch tasks; message values are JSON, UTF-8 encoded.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda payload: json.dumps(payload).encode('utf-8'),
)


def _redis_db_client(db):
    """Return a client for logical database *db* on the shared Redis server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


# Service databases: task records, API-key metadata, API-key usage.
redis_client = _redis_db_client(REDIS_DB)
redis_api_client = _redis_db_client(REDIS_API_DB)
redis_api_usage_client = _redis_db_client(REDIS_API_USAGE_DB)

# Per-worker databases from which task results are read, as configured in
# WORKER_CONFIGS.
redis_pose_client = _redis_db_client(WORKER_CONFIGS['pose']['redis_db'])
redis_cpm_client = _redis_db_client(WORKER_CONFIGS['cpm']['redis_db'])
redis_yolo_client = _redis_db_client(WORKER_CONFIGS['yolo']['redis_db'])
redis_face_client = _redis_db_client(WORKER_CONFIGS['face']['redis_db'])
redis_fall_client = _redis_db_client(WORKER_CONFIGS['fall']['redis_db'])
redis_mediapipe_client = _redis_db_client(WORKER_CONFIGS['mediapipe']['redis_db'])
redis_qwenvl_client = _redis_db_client(WORKER_CONFIGS['qwenvl']['redis_db'])


@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve ``static/favicon.ico``, or a JSON 404 when the file is absent.

    Returns:
        ``FileResponse`` with the icon, or ``JSONResponse`` with status 404.
    """
    # NOTE(review): in FastAPI ``root_path`` is a URL prefix (for proxies),
    # not a filesystem directory; with an empty root_path this resolves to
    # ``static/favicon.ico`` relative to the process working directory.
    file_path = os.path.join(app.root_path, "static", "favicon.ico")
    if os.path.isfile(file_path):
        return FileResponse(file_path)
    # A ``(body, status)`` tuple is a Flask idiom; FastAPI would serialize the
    # tuple as a JSON array with status 200, so set the status explicitly.
    return JSONResponse(status_code=404, content={"message": "Favicon not found"})


# The API key is read from the ``Authorization`` header. auto_error=False
# stops FastAPI from rejecting a missing header itself, so get_api_key can
# issue its own 401 response.
api_key_header = APIKeyHeader(auto_error=False, name="Authorization")


# Base62 alphabet: the ten ASCII digits, then lowercase, then uppercase
# letters (62 symbols in total).
# NOTE(review): not referenced anywhere else in this file.
BASE62 = "".join((string.digits, string.ascii_letters))


async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the bearer token from the ``Authorization`` header.

    The header must start with ``"Bearer "``; the token is the text after
    that prefix up to the next space, and it must start with ``obs-``.

    Returns:
        The raw API key string (including the ``obs-`` prefix).

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the header
            is missing, is not a Bearer header, or the token lacks ``obs-``.
    """
    bearer_prefix = "Bearer "
    token = None
    if api_key and api_key.startswith(bearer_prefix):
        token = api_key[len(bearer_prefix):].partition(" ")[0]

    if token is None or not token.startswith("obs-"):
        raise HTTPException(
            status_code=401,
            detail="无效的API密钥",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return token


def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of *api_key*: its first 8 characters plus ``***``."""
    return f"{api_key[:8]}***" if api_key else "<empty>"


def _parse_expires_at(raw):
    """Parse a stored ``expires_at`` ISO-8601 string into an aware datetime.

    A trailing ``Z`` is rewritten to ``+00:00`` because
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 on.
    Returns ``None`` when *raw* is empty/missing or not a valid timestamp.
    """
    if not raw:
        return None
    if raw.endswith("Z"):
        raw = raw[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(raw)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        # NOTE(review): assumes naive timestamps were written in UTC — confirm
        # with whatever issues the keys. Comparing a naive value against the
        # aware "now" below would otherwise raise TypeError.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate *api_key* against Redis and return its metadata and usage.

    Looks up ``api_key:{api_key}`` in the API-key database, requires
    ``is_active == '1'`` and an ``expires_at`` in the future, then merges in
    the hash stored under the same key in the usage database.

    Returns:
        ``{"api_key": api_key, **key_metadata, **usage}`` with string values;
        usage fields win on name clashes.

    Raises:
        HTTPException: 401 if the key is unknown, inactive, has a missing or
            unparseable ``expires_at``, or has expired.
    """
    # Only a short prefix of the secret ever reaches the logs.
    masked_key = _mask_api_key(api_key)
    logging.info(f"验证API密钥: {masked_key}")
    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        logging.warning(f"API密钥不存在: {masked_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}

    if api_key_info.get('is_active') != '1':
        logging.warning(f"API密钥已停用: {masked_key}")
        raise HTTPException(status_code=401, detail="API密钥已停用")

    # Reject keys whose expiry cannot be determined instead of letting
    # fromisoformat()/the comparison raise and surface as an HTTP 500.
    expires_at = _parse_expires_at(api_key_info.get('expires_at'))
    if expires_at is None:
        logging.warning(f"API密钥过期时间无效: {masked_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    if datetime.now(timezone.utc) > expires_at:
        logging.warning(f"API密钥已过期: {masked_key}")
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    logging.info(f"API密钥验证成功: {masked_key}")
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info,
    }


async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add *new_tokens_used* to the key's overall and per-model token counters.

    Each counter gets a matching ``*last_used_at`` field set to the current
    UTC time in ISO-8601 form. All four writes go to the usage database
    through a single pipeline.
    """
    usage_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()
    counter_fields = (
        ("tokens_used", "last_used_at"),  # totals across every model
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),  # this model only
    )

    with redis_api_usage_client.pipeline() as pipe:
        for tokens_field, timestamp_field in counter_fields:
            pipe.hincrby(usage_key, tokens_field, new_tokens_used)
            pipe.hset(usage_key, timestamp_field, now_iso)
        pipe.execute()


def calculate_tokens(file_path: str, file_type: str) -> int:
    """Return the token cost charged for processing the file at *file_path*.

    Cost model as implemented (each term truncated with ``int``):

    * every file: ``size_in_MiB * 10``
    * ``"image"``: plus ``width * height / 10000 * 5``
    * ``"video"``: plus ``(width * height * frame_count / 10000) * (duration_s / 60)``

    NOTE(review): earlier inline comments described the rates as "5 tokens per
    100 pixels" and "1 token per million pixels per second"; neither matches
    the arithmetic above. The formulas are left unchanged because they define
    billing — confirm which rate is intended.

    Args:
        file_path: Path of the saved upload.
        file_type: ``"image"``, ``"video"``, or anything else (size cost only).

    Returns:
        The cost, at least 1. Any error while inspecting the file is logged
        and yields 1 rather than failing the upload.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes
        tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            # The context manager closes the file handle Image.open keeps;
            # only the header is needed to report the size.
            with Image.open(file_path) as img:
                width, height = img.size
            pixel_count = width * height
            tokens += int((pixel_count / 10000) * 5)

        elif file_type == "video":
            vr = VideoReader(file_path)
            fps = vr.get_avg_fps()
            frame_count = len(vr)
            # Decode the first frame once; its shape is (height, width, ...).
            first_frame_shape = vr[0].shape
            height, width = first_frame_shape[0], first_frame_shape[1]

            pixel_count = width * height * frame_count
            duration = frame_count / fps  # seconds
            tokens += int((pixel_count / 10000) * (duration / 60))

        return max(1, tokens)
    except Exception as e:
        # Deliberate best-effort fallback: unreadable or undecodable files are
        # charged the minimum instead of rejecting the upload.
        logging.exception("计算token时出错: %s", e)
        return 1


async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
    """Save an uploaded file, charge its token cost, and queue it for a worker.

    Args:
        file: Multipart upload received by one of the per-task endpoints.
        task_type: Must be a key of ``KAFKA_TOPICS``; selects the Kafka topic
            and the per-model usage counter.
        api_key_info: Mapping produced by ``verify_api_key``; only its
            ``"api_key"`` entry is used here.

    Returns:
        ``JSONResponse`` with the new ``task_id``, the tokens charged for this
        upload, the running totals and the remaining balance.

    Raises:
        HTTPException: 400 for an unknown ``task_type``; 403 when the key's
            balance cannot cover this upload (the saved file is removed).
    """
    if task_type not in KAFKA_TOPICS:
        raise HTTPException(status_code=400, detail="不支持的任务类型")

    content = await file.read()
    # Clients may omit the filename; treat that as "no extension" instead of
    # crashing in os.path.splitext(None).
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"
    file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(file_path, "wb") as f:
        f.write(content)

    # Only these extensions are treated as images; everything else is costed
    # and queued as a video.
    file_type = "image" if file_extension in ('.jpg', '.jpeg', '.png') else "video"
    tokens_required = calculate_tokens(file_path, file_type)

    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    # NOTE(review): this read-check-then-increment sequence is not atomic;
    # concurrent uploads for the same key can each pass the check and
    # overdraw the balance.
    if tokens_used + tokens_required > total_tokens:
        # A rejected upload is never queued, so don't leave it on disk.
        try:
            os.remove(file_path)
        except OSError:
            logging.warning(f"Could not remove rejected upload: {file_path}")
        raise HTTPException(status_code=403, detail="Token 余额不足")

    await update_token_usage(api_key, tokens_required, task_type)

    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "filename": new_filename,
        "file_type": file_type,
        "task_type": task_type,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Stored as a hash: get_result and get_annotated_image read this key with
    # HGETALL, which fails with WRONGTYPE on a plain string value.
    redis_client.hset(f"task:{task_id}", mapping=task_data)
    logging.info(f"任务信息已存储到Redis: {task_id}")

    kafka_topic = KAFKA_TOPICS[task_type]
    producer.send(kafka_topic, task_data)
    logging.info(f"任务已发送到Kafka主题: {kafka_topic}")

    # Read the updated counters back directly. Re-running verify_api_key here
    # would repeat the whole key validation and could reject a request that
    # has already been charged and queued.
    raw_total_used, raw_model_used = redis_api_usage_client.hmget(
        usage_key, "tokens_used", f"{task_type}_tokens_used"
    )
    new_tokens_used = int(raw_total_used or 0)
    model_tokens_used = int(raw_model_used or 0)

    response_data = {
        "message": "文件已上传并排队等待处理",
        "task_id": task_id,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{task_type}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used,
    }
    logging.info(f"上传文件完成: {task_id}")
    return JSONResponse(content=response_data)


# One upload endpoint per task type. Each takes a single multipart file,
# authenticates through verify_api_key, and delegates to upload_file() with a
# fixed task_type. (Comments rather than docstrings keep the generated
# OpenAPI descriptions unchanged.)
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    logging.info("收到 /pose端点的请求")
    return await upload_file(file=file, task_type="pose", api_key_info=api_key_info)


@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    return await upload_file(file=file, task_type="cpm", api_key_info=api_key_info)


@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    return await upload_file(file=file, task_type="qwenvl", api_key_info=api_key_info)


@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    return await upload_file(file=file, task_type="yolo", api_key_info=api_key_info)


@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    return await upload_file(file=file, task_type="fall", api_key_info=api_key_info)


@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    logging.info("收到 /face 端点的请求")
    return await upload_file(file=file, task_type="face", api_key_info=api_key_info)


@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    return await upload_file(file=file, task_type="mediapipe", api_key_info=api_key_info)


@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
    """Return a finished task's result, or its current status while pending.

    The ``task:{task_id}`` hash in the main database supplies ``status``,
    ``result_type`` (which worker database holds the result) and
    ``result_key`` (the hash to read there). A ``result`` field containing
    JSON text is decoded before being returned.
    """
    # verify_api_key raises HTTPException 401 for unknown, inactive or expired
    # keys, so its return value needs no further check here.
    await verify_api_key(api_key)

    # NOTE(review): task records written by this service do not store the
    # owning API key, so any valid key can read any task_id.
    task_info = redis_client.hgetall(f"task:{task_id}")
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")
    task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}

    status = task_info.get('status')
    if status != 'completed':
        return {"status": status, "message": "Task is not completed yet"}

    # A completed task lacking its result pointer is reported as 404 rather
    # than crashing with KeyError (matches get_annotated_image).
    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")

    # Worker database that holds results for each task type.
    redis_client_map = {
        'pose': redis_pose_client,
        'cpm': redis_cpm_client,
        'yolo': redis_yolo_client,
        'face': redis_face_client,
        'fall': redis_fall_client,
        'mediapipe': redis_mediapipe_client,
        'qwenvl': redis_qwenvl_client,
    }
    result_redis = redis_client_map.get(result_type)
    if result_redis is None:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    if 'result' in result:
        try:
            result['result'] = json.loads(result['result'])
        except json.JSONDecodeError:
            # Return the raw text rather than failing the whole request.
            logging.warning(f"Task {task_id} result is not valid JSON; returning raw string")

    return {
        "status": "completed",
        "result_type": result_type,
        "result": result,
    }


@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
    """Serve the annotated output file of a completed task.

    Not available for ``cpm`` and ``qwenvl`` tasks. The file name is taken
    from the ``result_file`` field of the task's result hash and resolved
    under ``RESULT_DIR``.
    """
    # verify_api_key raises HTTPException 401 for unknown, inactive or expired
    # keys, so its return value needs no further check here.
    await verify_api_key(api_key)

    task_info = redis_client.hgetall(f"task:{task_id}")
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")
    task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}

    if task_info.get('status') != 'completed':
        raise HTTPException(status_code=400, detail="Task is not completed yet")

    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")
    if result_type in ('cpm', 'qwenvl'):
        raise HTTPException(status_code=400, detail="Annotated image not available for this task type")

    # Worker databases whose results may reference an annotated output file.
    redis_client_map = {
        'pose': redis_pose_client,
        'yolo': redis_yolo_client,
        'face': redis_face_client,
        'fall': redis_fall_client,
        'mediapipe': redis_mediapipe_client,
    }
    result_redis = redis_client_map.get(result_type)
    if result_redis is None:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    result_file = result.get('result_file')
    if not result_file:
        raise HTTPException(status_code=404, detail="Result file not found")

    # NOTE(review): if a worker stores an absolute path, os.path.join ignores
    # RESULT_DIR entirely — confirm the format workers write.
    file_path = os.path.join(RESULT_DIR, result_file)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Result image file not found")

    # Result files are not guaranteed to be PNG (e.g. JPEG or video output);
    # derive the type from the extension and keep PNG as the fallback.
    media_type, _ = mimetypes.guess_type(file_path)
    return FileResponse(file_path, media_type=media_type or "image/png")


if __name__ == "__main__":
    # Script entry point: serve the root app (with the /v1 sub-application)
    # on all interfaces, port 8005.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8005)