Files
2025-04-10 09:45:41 +00:00

438 lines
16 KiB
Python
Executable File

# main.py
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timedelta, timezone
import string
from decord import VideoReader
from PIL import Image
from fastapi.responses import FileResponse
import logging
from config import *
# Root application; the versioned API lives on a sub-application mounted at /v1.
app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)

# CORS settings applied to the root app (requests to the /v1 mount pass through it).
# NOTE(review): wildcard origins together with allow_credentials=True allows
# credentialed cross-origin requests from any site; confirm this is intended.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Configuration.
# KAFKA_BROKER, REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, REDIS_API_DB,
# REDIS_API_USAGE_DB, UPLOAD_DIR, RESULT_DIR and WORKER_CONFIGS are used as-is
# from `from config import *`; re-assigning them to themselves was a no-op.
# Only the names that differ from config are defined here.
REDIS_DB = MAIN_REDIS_DB
MAX_FILE_AGE = timedelta(hours=1)

# Make sure the upload and result directories exist.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# Supported task types. Every task type is published to a Kafka topic with the
# same name, so the topic map is the identity over this ordered tuple.
_TASK_NAMES = (
    'pose',
    'mediapipe',
    'qwenvl',
    'yolo',
    'fall',
    'face',
    'cpm',
    'compare',
    'qwenvl_analyze',
    'cpm_analyze',
)
KAFKA_TOPICS = {name: name for name in _TASK_NAMES}
TASK_TYPES = [*KAFKA_TOPICS]
# Kafka producer: message values are JSON-encoded and sent as UTF-8 bytes.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda payload: json.dumps(payload).encode('utf-8'),
)


def _make_redis(db):
    """Return a Redis client for database ``db`` on the configured server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


def _make_worker_redis(task_name):
    """Return a Redis client for the database named in WORKER_CONFIGS[task_name]."""
    return _make_redis(WORKER_CONFIGS[task_name]['redis_db'])


# Task records ("task:<id>").
redis_client = _make_redis(REDIS_DB)
# API key metadata and API key usage counters ("api_key:<key>").
redis_api_client = _make_redis(REDIS_API_DB)
redis_api_usage_client = _make_redis(REDIS_API_USAGE_DB)

# Per-worker result databases.
redis_pose_client = _make_worker_redis('pose')
redis_cpm_client = _make_worker_redis('cpm')
redis_yolo_client = _make_worker_redis('yolo')
redis_face_client = _make_worker_redis('face')
redis_fall_client = _make_worker_redis('fall')
redis_mediapipe_client = _make_worker_redis('mediapipe')
redis_qwenvl_client = _make_worker_redis('qwenvl')
redis_compare_client = _make_worker_redis('compare')
redis_qwenvl_analyze_client = _make_worker_redis('qwenvl_analyze')
redis_cpm_analyze_client = _make_worker_redis('cpm_analyze')
@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve ``static/favicon.ico`` (joined onto ``app.root_path``) if it exists.

    Returns:
        A FileResponse for the icon, or a JSON body with HTTP status 404 when
        the file is missing.
    """
    file_path = os.path.join(app.root_path, "static", "favicon.ico")
    if os.path.isfile(file_path):
        return FileResponse(file_path)
    # FastAPI does not interpret a (body, status) tuple. It would serialize the
    # tuple as a JSON array and return 200, so build the 404 response explicitly.
    return JSONResponse(status_code=404, content={"message": "Favicon not found"})
# API keys are read from the Authorization header. auto_error=False makes a
# missing header arrive as None, so get_api_key can raise its own 401.
_API_KEY_HEADER_NAME = "Authorization"
api_key_header = APIKeyHeader(name=_API_KEY_HEADER_NAME, auto_error=False)
# Base62 alphabet: the ten digits followed by lowercase, then uppercase ASCII letters.
BASE62 = "".join((string.digits, string.ascii_letters))
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract an ``obs-`` API key from an ``Authorization: Bearer ...`` header.

    The token is the second space-separated field of the header value.

    Returns:
        The token string when the header starts with ``"Bearer "`` and the
        token starts with ``"obs-"``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header otherwise.
    """
    has_bearer_scheme = bool(api_key) and api_key.startswith("Bearer ")
    candidate = api_key.split(" ")[1] if has_bearer_scheme else ""
    if candidate.startswith("obs-"):
        return candidate
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of ``api_key`` that keeps only a short prefix."""
    return f"{api_key[:8]}***" if len(api_key) > 8 else "***"


def _parse_expires_at(raw):
    """Parse a stored ISO-8601 expiry into an aware datetime, or return None if invalid.

    A trailing ``Z`` is accepted (``datetime.fromisoformat`` rejects it before
    Python 3.11). Timestamps without an offset are treated as UTC.
    """
    if not raw:
        return None
    if raw.endswith("Z"):
        raw = raw[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(raw)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        # NOTE(review): naive timestamps are assumed to be UTC; confirm with the
        # code that writes expires_at. Comparing a naive value with the aware
        # "now" below would otherwise raise TypeError.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate ``api_key`` against Redis and return its merged metadata.

    The key's metadata hash ``api_key:<key>`` is read from ``redis_api_client``
    and its usage hash (same key) from ``redis_api_usage_client``.

    Returns:
        A dict containing ``api_key``, every metadata field, and every usage
        field. Usage fields win on a name clash.

    Raises:
        HTTPException: 401 if the key is unknown, not active (``is_active`` != "1"),
        has a missing or unparseable ``expires_at``, or has expired.
    """
    # Never write the full secret to the logs.
    masked_key = _mask_api_key(api_key)
    logging.info(f"验证API密钥: {masked_key}")
    redis_key = f"api_key:{api_key}"

    api_key_info = redis_api_client.hgetall(redis_key)
    if not api_key_info:
        logging.warning(f"API密钥不存在: {masked_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}

    if api_key_info.get('is_active') != '1':
        logging.warning(f"API密钥已停用: {masked_key}")
        raise HTTPException(status_code=401, detail="API密钥已停用")

    # A missing or malformed expiry used to surface as a 500. Fail closed with a 401 instead.
    expires_at = _parse_expires_at(api_key_info.get('expires_at'))
    if expires_at is None:
        logging.warning(f"API密钥过期时间无效: {masked_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    if datetime.now(timezone.utc) > expires_at:
        logging.warning(f"API密钥已过期: {masked_key}")
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
    logging.info(f"API密钥验证成功: {masked_key}")
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info,
    }
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Charge ``new_tokens_used`` tokens to ``api_key`` in the usage database.

    Increments the overall ``tokens_used`` counter and the per-model
    ``<model_name>_tokens_used`` counter of hash ``api_key:<key>``. Also stamps
    ``last_used_at`` and ``<model_name>_last_used_at`` with the current UTC time
    in ISO format. All four commands are queued on one pipeline and executed together.
    """
    usage_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()
    (
        redis_api_usage_client.pipeline()
        .hincrby(usage_key, "tokens_used", new_tokens_used)
        .hset(usage_key, "last_used_at", now_iso)
        .hincrby(usage_key, f"{model_name}_tokens_used", new_tokens_used)
        .hset(usage_key, f"{model_name}_last_used_at", now_iso)
        .execute()
    )
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Estimate the token cost of processing the file at ``file_path``.

    The cost is the sum of the terms below, each truncated with ``int()``, with
    a floor of 1:

    * size term (all files): ``size_in_MiB * 0.1``
    * image term (``file_type == "image"``): ``width * height / 1e8 * 0.1``
    * video term (``file_type == "video"``):
      ``width * height * frame_count / 1e8 * duration_seconds / 60``

    Any error while inspecting the file (missing file, undecodable media, zero
    fps, ...) is logged and the minimum cost of 1 is returned.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes
        base_tokens = int((file_size / (1024 * 1024)) * 0.1)
        if file_type == "image":
            # Close the image handle explicitly; only the header is needed for the size.
            with Image.open(file_path) as img:
                width, height = img.size
            pixel_count = width * height
            base_tokens += int((pixel_count / 100000000) * 0.1)
        elif file_type == "video":
            vr = VideoReader(file_path)
            fps = vr.get_avg_fps()
            frame_count = len(vr)
            # Decode the first frame once. shape[0] is used as height and
            # shape[1] as width.
            first_frame_shape = vr[0].shape
            width, height = first_frame_shape[1], first_frame_shape[0]
            pixel_count = width * height * frame_count
            duration = frame_count / fps  # seconds
            base_tokens += int((pixel_count / 100000000) * (duration / 60))
        return max(1, base_tokens)
    except Exception as e:
        # Deliberate best-effort: billing must not fail the upload.
        logging.warning(f"计算token时出错: {str(e)}")
        return 1
# Uploads are streamed to disk in chunks of this many bytes.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
    """Save an uploaded file, charge tokens for it, and queue a processing task.

    Args:
        file: The uploaded image or video.
        task_type: A key of ``KAFKA_TOPICS``; selects the Kafka topic and the
            per-model usage counter.
        api_key_info: The dict returned by ``verify_api_key``; its ``api_key``
            entry identifies the account to charge.

    Returns:
        JSONResponse with the task id, the tokens charged, and the updated totals.

    Raises:
        HTTPException: 400 for an unsupported task type; 403 when the token
        balance (``total_tokens`` - ``tokens_used``) cannot cover the upload.
    """
    if task_type not in KAFKA_TOPICS:
        raise HTTPException(status_code=400, detail="不支持的任务类型")

    # UploadFile.filename may be None. Treat that as "no extension" instead of
    # crashing inside os.path.splitext.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"
    file_path = os.path.join(UPLOAD_DIR, new_filename)
    # Stream to disk so large videos are not held in memory all at once.
    with open(file_path, "wb") as out:
        while True:
            chunk = await file.read(_UPLOAD_CHUNK_SIZE)
            if not chunk:
                break
            out.write(chunk)

    # Only these extensions count as images. Anything else is billed and
    # dispatched as a video.
    file_type = "image" if file_extension in ('.jpg', '.jpeg', '.png') else "video"
    tokens_required = calculate_tokens(file_path, file_type)

    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    # NOTE(review): this balance check and the increment in update_token_usage
    # are separate round trips, so concurrent uploads can overdraw the balance.
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    if tokens_used + tokens_required > total_tokens:
        # The task will never run, so do not leave the upload behind on disk.
        try:
            os.remove(file_path)
        except OSError:
            logging.warning(f"无法删除被拒绝的上传文件: {file_path}")
        raise HTTPException(status_code=403, detail="Token 余额不足")

    await update_token_usage(api_key, tokens_required, task_type)

    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "filename": new_filename,
        "file_type": file_type,
        "task_type": task_type,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    # get_result and get_annotated_image read task records with HGETALL, so the
    # record must be a Redis hash. A plain string value (SET) makes them fail with WRONGTYPE.
    redis_client.hset(f"task:{task_id}", mapping=task_data)
    logging.info(f"任务信息已存储到Redis: {task_id}")

    kafka_topic = KAFKA_TOPICS[task_type]
    producer.send(kafka_topic, task_data)
    logging.info(f"任务已发送到Kafka主题: {kafka_topic}")

    # Read back the counters update_token_usage just incremented, straight from
    # the usage store, instead of re-running full key validation.
    new_tokens_raw, model_tokens_raw = redis_api_usage_client.hmget(
        usage_key, "tokens_used", f"{task_type}_tokens_used"
    )
    new_tokens_used = int(new_tokens_raw or 0)
    model_tokens_used = int(model_tokens_raw or 0)

    response_data = {
        "message": "文件已上传并排队等待处理",
        "task_id": task_id,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{task_type}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used,
    }
    logging.info(f"上传文件完成: {task_id}")
    return JSONResponse(content=response_data)
# One upload endpoint per task type. Each handler authenticates the caller with
# verify_api_key and delegates to upload_file with its fixed task type.
# Handler names must be unique: the /qwenvl_analyze and /cpm_analyze handlers
# used to reuse the name upload_qwenvl, shadowing each other at module level.
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``pose`` task."""
    logging.info("收到 /pose端点的请求")
    return await upload_file(file, task_type="pose", api_key_info=api_key_info)


@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``cpm`` task."""
    return await upload_file(file, task_type="cpm", api_key_info=api_key_info)


@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``qwenvl`` task."""
    return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)


@v1_app.post("/qwenvl_analyze")
async def upload_qwenvl_analyze(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``qwenvl_analyze`` task."""
    return await upload_file(file, task_type="qwenvl_analyze", api_key_info=api_key_info)


@v1_app.post("/cpm_analyze")
async def upload_cpm_analyze(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``cpm_analyze`` task."""
    return await upload_file(file, task_type="cpm_analyze", api_key_info=api_key_info)


@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``yolo`` task."""
    return await upload_file(file, task_type="yolo", api_key_info=api_key_info)


@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``fall`` task."""
    return await upload_file(file, task_type="fall", api_key_info=api_key_info)


@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``face`` task."""
    return await upload_file(file, task_type="face", api_key_info=api_key_info)


@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``mediapipe`` task."""
    return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)


@v1_app.post("/compare")
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an uploaded file for the ``compare`` task."""
    return await upload_file(file, task_type="compare", api_key_info=api_key_info)
@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
    """Return a task's status and, once it is completed, the worker's result.

    Returns:
        ``{"status": <status>, "message": ...}`` while the task is not completed.
        Otherwise ``{"status": "completed", "result_type": ..., "result": {...}}``,
        where ``result["result"]`` is JSON-decoded when possible.

    Raises:
        HTTPException: 401 from ``verify_api_key`` for an invalid key; 404 when
        the task, its result key, or the result record is missing; 400 for an
        unsupported result type.
    """
    # verify_api_key raises on every invalid key, so its return value needs no check.
    await verify_api_key(api_key)
    # NOTE(review): task records do not store the owning API key, so any valid
    # key can read any task's result if it knows the task_id.

    task_info = redis_client.hgetall(f"task:{task_id}")
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")
    task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}

    status = task_info.get('status')
    if status != 'completed':
        return {"status": status, "message": "Task is not completed yet"}

    # Missing fields used to raise KeyError (500). Report them as missing instead.
    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")

    # Each task type stores its results in its own Redis database.
    redis_client_map = {
        'pose': redis_pose_client,
        'cpm': redis_cpm_client,
        'yolo': redis_yolo_client,
        'face': redis_face_client,
        'fall': redis_fall_client,
        'mediapipe': redis_mediapipe_client,
        'qwenvl': redis_qwenvl_client,
        'compare': redis_compare_client,
        'qwenvl_analyze': redis_qwenvl_analyze_client,
        'cpm_analyze': redis_cpm_analyze_client,
    }
    result_redis = redis_client_map.get(result_type)
    if not result_redis:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    # Decode the 'result' field as JSON when possible. A non-JSON payload is
    # returned as the raw string rather than failing the request with a 500.
    if 'result' in result:
        try:
            result['result'] = json.loads(result['result'])
        except json.JSONDecodeError:
            logging.warning(f"任务结果不是有效的JSON: {task_id}")

    return {
        "status": "completed",
        "result_type": result_type,
        "result": result,
    }
@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
    """Return the annotated output file of a completed task.

    Supported for pose, yolo, face, fall, mediapipe and compare results. The
    file name is taken from the ``result_file`` field of the worker's result
    record and resolved against ``RESULT_DIR``.

    Raises:
        HTTPException: 401 from ``verify_api_key`` for an invalid key; 400 when
        the task is not completed or its type has no annotated output; 404 when
        the task, result key, result record, or file is missing.
    """
    import mimetypes

    # verify_api_key raises on every invalid key, so its return value needs no check.
    await verify_api_key(api_key)

    task_info = redis_client.hgetall(f"task:{task_id}")
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")
    task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}

    if task_info.get('status') != 'completed':
        raise HTTPException(status_code=400, detail="Task is not completed yet")

    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")
    # Text-producing task types have no annotated file.
    if result_type in ['cpm', 'qwenvl', 'cpm_analyze', 'qwenvl_analyze']:
        raise HTTPException(status_code=400, detail="Annotated image not available for this task type")

    # Each task type stores its results in its own Redis database.
    redis_client_map = {
        'pose': redis_pose_client,
        'yolo': redis_yolo_client,
        'face': redis_face_client,
        'fall': redis_fall_client,
        'mediapipe': redis_mediapipe_client,
        'compare': redis_compare_client,
    }
    result_redis = redis_client_map.get(result_type)
    if not result_redis:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}

    result_file = result.get('result_file')
    if not result_file:
        raise HTTPException(status_code=404, detail="Result file not found")
    # NOTE(review): os.path.join discards RESULT_DIR if result_file is absolute.
    # This is acceptable only while result_file is written by trusted workers.
    file_path = os.path.join(RESULT_DIR, result_file)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Result image file not found")

    # Label the response by the file's extension rather than always claiming PNG.
    # Fall back to PNG when the extension is unknown.
    media_type = mimetypes.guess_type(file_path)[0] or "image/png"
    return FileResponse(file_path, media_type=media_type)
if __name__ == "__main__":
    import uvicorn

    # Serve the root app (which mounts the /v1 sub-application) on all interfaces.
    server_options = {"host": "0.0.0.0", "port": 8005}
    uvicorn.run(app, **server_options)