# main.py
"""HTTP API that accepts media uploads and queues them for worker processing.

Each upload is written to ``UPLOAD_DIR``, charged against the caller's token
quota in Redis, recorded as a task in Redis and published to the Kafka topic
named after the task type. Results written by the workers are served back
through ``/v1/result/{task_id}`` and ``/v1/annotated/{task_id}``.
"""
import json
import logging
import os
import string
import uuid
from datetime import datetime, timedelta, timezone

from decord import VideoReader
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from PIL import Image
from redis import Redis
from redis.exceptions import ResponseError

from config import *

app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration.
# KAFKA_BROKER, REDIS_HOST, REDIS_PORT, REDIS_PASSWORD, REDIS_API_DB,
# REDIS_API_USAGE_DB, UPLOAD_DIR, RESULT_DIR and WORKER_CONFIGS are provided by
# the star import from ``config``; only the main task DB is aliased here.
REDIS_DB = MAIN_REDIS_DB
MAX_FILE_AGE = timedelta(hours=1)

# Make sure the working directories exist.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Supported task types, mapped to the Kafka topic each one is published to.
KAFKA_TOPICS = {
    'pose': 'pose',
    'mediapipe': 'mediapipe',
    'qwenvl': 'qwenvl',
    'yolo': 'yolo',
    'fall': 'fall',
    'face': 'face',
    'cpm': 'cpm',
    'compare': 'compare',
    'qwenvl_analyze': 'qwenvl_analyze',
    'cpm_analyze': 'cpm_analyze',
}
TASK_TYPES = list(KAFKA_TOPICS.keys())

# File extensions treated as images; every other upload is treated as video.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Kafka producer; task payloads are serialised as UTF-8 JSON.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
)


def _make_redis(db) -> Redis:
    """Return a Redis client for logical database *db* on the configured server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


def _worker_redis(task_type: str) -> Redis:
    """Return a Redis client for the result database configured for *task_type*."""
    return _make_redis(WORKER_CONFIGS[task_type]['redis_db'])


# Redis clients: task records, API key metadata, API key usage counters.
redis_client = _make_redis(REDIS_DB)
redis_api_client = _make_redis(REDIS_API_DB)
redis_api_usage_client = _make_redis(REDIS_API_USAGE_DB)

# One client per worker result database.
redis_pose_client = _worker_redis('pose')
redis_cpm_client = _worker_redis('cpm')
redis_yolo_client = _worker_redis('yolo')
redis_face_client = _worker_redis('face')
redis_fall_client = _worker_redis('fall')
redis_mediapipe_client = _worker_redis('mediapipe')
redis_qwenvl_client = _worker_redis('qwenvl')
redis_compare_client = _worker_redis('compare')
redis_qwenvl_analyze_client = _worker_redis('qwenvl_analyze')
redis_cpm_analyze_client = _worker_redis('cpm_analyze')

# Result type -> Redis client holding that worker's results.
RESULT_REDIS_CLIENTS = {
    'pose': redis_pose_client,
    'cpm': redis_cpm_client,
    'yolo': redis_yolo_client,
    'face': redis_face_client,
    'fall': redis_fall_client,
    'mediapipe': redis_mediapipe_client,
    'qwenvl': redis_qwenvl_client,
    'compare': redis_compare_client,
    'qwenvl_analyze': redis_qwenvl_analyze_client,
    'cpm_analyze': redis_cpm_analyze_client,
}

# Result types for which /annotated/{task_id} refuses to serve a file.
NON_ANNOTATED_RESULT_TYPES = frozenset(
    {'cpm', 'qwenvl', 'cpm_analyze', 'qwenvl_analyze'}
)


def _decode_hash(raw: dict) -> dict:
    """Decode the bytes keys and values of a Redis hash reply into ``str``."""
    return {k.decode('utf-8'): v.decode('utf-8') for k, v in raw.items()}


def _load_task_info(task_id: str) -> dict:
    """Return the task record stored under ``task:{task_id}``, or ``{}``.

    The record is read as a Redis hash. ``upload_file`` stores the initial
    record as a JSON string, for which HGETALL fails with WRONGTYPE; in that
    case the string is fetched and parsed instead, so a freshly queued task
    can be reported rather than surfacing as an internal server error.
    """
    key = f"task:{task_id}"
    try:
        return _decode_hash(redis_client.hgetall(key))
    except ResponseError:
        payload = redis_client.get(key)
    if not payload:
        return {}
    try:
        record = json.loads(payload)
    except (TypeError, ValueError):
        return {}
    return record if isinstance(record, dict) else {}


@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
    """Serve ``static/favicon.ico`` under the app root path, or answer 404."""
    file_path = os.path.join(app.root_path, "static", "favicon.ico")
    if os.path.isfile(file_path):
        return FileResponse(file_path)
    # Returning a ``(body, 404)`` tuple would be serialised as a JSON array
    # with status 200, so the status code is set on an explicit response.
    return JSONResponse(status_code=404, content={"message": "Favicon not found"})


# API key header definition.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

# Base62 alphabet.
BASE62 = string.digits + string.ascii_letters


async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the API key from an ``Authorization: Bearer obs-...`` header.

    Returns:
        The key (the token following ``Bearer ``) when it starts with ``obs-``.

    Raises:
        HTTPException: 401 when the header is missing or not in that form.
    """
    if api_key and api_key.startswith("Bearer "):
        key = api_key.split(" ")[1]
        if key.startswith("obs-"):
            return key
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )


def _parse_expiry(raw_value):
    """Parse an ISO-8601 expiry timestamp; return ``None`` if absent or malformed.

    Naive timestamps are interpreted as UTC so they can be compared with the
    timezone-aware "now" used by ``verify_api_key``.
    NOTE(review): UTC is assumed because this module writes its own timestamps
    in UTC; confirm against whatever creates the API key records.
    """
    if not raw_value:
        return None
    try:
        expires_at = datetime.fromisoformat(raw_value)
    except ValueError:
        return None
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    return expires_at


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate *api_key* against Redis and return its metadata and usage.

    Returns:
        A dict holding ``api_key`` plus every field of the key's metadata hash
        and of its usage hash (usage fields win on a name clash).

    Raises:
        HTTPException: 401 when the key is unknown, inactive, expired, or has
            a missing or unparsable ``expires_at``.
    """
    # NOTE(review): the full API key is written to the log here and below;
    # consider masking it.
    logging.info(f"验证API密钥: {api_key}")
    redis_key = f"api_key:{api_key}"

    api_key_info = _decode_hash(redis_api_client.hgetall(redis_key))
    if not api_key_info:
        logging.warning(f"API密钥不存在: {api_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")

    if api_key_info.get('is_active') != '1':
        logging.warning(f"API密钥已停用: {api_key}")
        raise HTTPException(status_code=401, detail="API密钥已停用")

    expires_at = _parse_expiry(api_key_info.get('expires_at'))
    if expires_at is None:
        # Fail closed: a record without a usable expiry is rejected instead of
        # crashing the request with an unhandled TypeError/ValueError.
        logging.warning(f"API密钥过期时间无效: {api_key}")
        raise HTTPException(status_code=401, detail="无效的API密钥")
    if datetime.now(timezone.utc) > expires_at:
        logging.warning(f"API密钥已过期: {api_key}")
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = _decode_hash(redis_api_usage_client.hgetall(redis_key))

    logging.info(f"API密钥验证成功: {api_key}")
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info,
    }


async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add *new_tokens_used* to the key's total and per-model usage counters.

    The usage hash gets ``tokens_used`` and ``{model_name}_tokens_used``
    incremented and the matching ``last_used_at`` fields set to the current
    UTC time; all four commands are sent through one pipeline.
    """
    redis_key = f"api_key:{api_key}"
    current_time = datetime.now(timezone.utc).isoformat()

    pipe = redis_api_usage_client.pipeline()

    # Overall usage.
    pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
    pipe.hset(redis_key, "last_used_at", current_time)

    # Per-model usage.
    pipe.hincrby(redis_key, f"{model_name}_tokens_used", new_tokens_used)
    pipe.hset(redis_key, f"{model_name}_last_used_at", current_time)

    pipe.execute()


def calculate_tokens(file_path: str, file_type: str) -> int:
    """Return the number of tokens to charge for the file at *file_path*.

    The charge is a size component plus, for ``"image"`` and ``"video"``
    files, a pixel-based component. The result is never below 1, and 1 is
    also returned when the file cannot be inspected.
    """
    try:
        file_size = os.path.getsize(file_path)  # bytes

        # Size component: 0.1 token per MiB, truncated.
        # NOTE(review): the original comments described much higher rates than
        # the arithmetic implements; the arithmetic is kept unchanged here.
        base_tokens = int((file_size / (1024 * 1024)) * 0.1)

        if file_type == "image":
            # The context manager closes the file handle that Image.open keeps.
            with Image.open(file_path) as img:
                width, height = img.size
            pixel_count = width * height
            # Image component: 0.1 token per 1e8 pixels, truncated.
            base_tokens += int((pixel_count / 100000000) * 0.1)
        elif file_type == "video":
            vr = VideoReader(file_path)
            fps = vr.get_avg_fps()
            frame_count = len(vr)
            # Read the first frame once; width and height come from its shape.
            first_frame_shape = vr[0].shape
            width, height = first_frame_shape[1], first_frame_shape[0]
            pixel_count = width * height * frame_count
            duration = frame_count / fps  # seconds
            # Video component: (total pixels / 1e8) * (duration in minutes).
            base_tokens += int((pixel_count / 100000000) * (duration / 60))

        return max(1, base_tokens)  # always charge at least one token
    except Exception as e:
        # Deliberate best effort: any failure to inspect the file falls back
        # to the minimum charge.
        logging.error(f"计算token时出错: {e}")
        return 1


async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
    """Store an upload, charge tokens for it and queue it for processing.

    Args:
        file: The uploaded file.
        task_type: One of the keys of ``KAFKA_TOPICS``.
        api_key_info: The dict returned by ``verify_api_key``.

    Returns:
        A ``JSONResponse`` with the task id and the caller's token figures.

    Raises:
        HTTPException: 400 for an unsupported task type, 403 when the key's
            remaining tokens do not cover the upload.
    """
    if task_type not in KAFKA_TOPICS:
        raise HTTPException(status_code=400, detail="不支持的任务类型")

    content = await file.read()
    # ``filename`` can be None for some multipart clients.
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"
    file_path = os.path.join(UPLOAD_DIR, new_filename)

    with open(file_path, "wb") as f:
        f.write(content)

    # Work out the charge.
    file_type = "image" if file_extension in IMAGE_EXTENSIONS else "video"
    tokens_required = calculate_tokens(file_path, file_type)

    # Check the quota, then record the usage.
    # NOTE(review): check-then-increment is not atomic; concurrent uploads on
    # the same key could overspend.
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        # A rejected upload is never processed, so do not leave it on disk.
        try:
            os.remove(file_path)
        except OSError as e:
            logging.warning(f"Failed to remove rejected upload {file_path}: {e}")
        raise HTTPException(status_code=403, detail="Token 余额不足")

    await update_token_usage(api_key, tokens_required, task_type)

    # Build the task record.
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "filename": new_filename,
        "file_type": file_type,
        "task_type": task_type,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Store the task record in Redis (as a JSON string).
    redis_client.set(f"task:{task_id}", json.dumps(task_data))
    logging.info(f"任务信息已存储到Redis: {task_id}")

    # Publish the task to its Kafka topic.
    kafka_topic = KAFKA_TOPICS[task_type]
    producer.send(kafka_topic, task_data)
    logging.info(f"任务已发送到Kafka主题: {kafka_topic}")

    # Re-read the usage counters so the response reflects this charge.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))

    response_data = {
        "message": "文件已上传并排队等待处理",
        "task_id": task_id,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{task_type}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used,
    }
    logging.info(f"上传文件完成: {task_id}")
    return JSONResponse(content=response_data)


# One upload endpoint per task type.
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``pose`` task."""
    logging.info("收到 /pose端点的请求")
    return await upload_file(file, task_type="pose", api_key_info=api_key_info)


@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``cpm`` task."""
    return await upload_file(file, task_type="cpm", api_key_info=api_key_info)


@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``qwenvl`` task."""
    return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)


# The next two handlers were previously also named ``upload_qwenvl``; each now
# has its own name so they no longer shadow one another.
@v1_app.post("/qwenvl_analyze")
async def upload_qwenvl_analyze(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``qwenvl_analyze`` task."""
    return await upload_file(file, task_type="qwenvl_analyze", api_key_info=api_key_info)


@v1_app.post("/cpm_analyze")
async def upload_cpm_analyze(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``cpm_analyze`` task."""
    return await upload_file(file, task_type="cpm_analyze", api_key_info=api_key_info)


@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``yolo`` task."""
    return await upload_file(file, task_type="yolo", api_key_info=api_key_info)


@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``fall`` task."""
    return await upload_file(file, task_type="fall", api_key_info=api_key_info)


@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``face`` task."""
    return await upload_file(file, task_type="face", api_key_info=api_key_info)


@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``mediapipe`` task."""
    return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)


@v1_app.post("/compare")
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue an upload for the ``compare`` task."""
    return await upload_file(file, task_type="compare", api_key_info=api_key_info)


@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
    """Return the status of a task and, once completed, its stored result.

    Raises:
        HTTPException: 401 for an invalid key; 404 when the task, its result
            key or its result hash is missing; 400 for an unknown result type.
    """
    # verify_api_key raises on failure, so no further check is needed here.
    await verify_api_key(api_key)

    # Task status lives in the main task database.
    task_info = _load_task_info(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    status = task_info.get('status')
    if status != 'completed':
        return {"status": status, "message": "Task is not completed yet"}

    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")

    # Pick the Redis client that holds results for this task type.
    result_redis = RESULT_REDIS_CLIENTS.get(result_type)
    if not result_redis:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = _decode_hash(result)

    # Parse the ``result`` field as JSON when present; keep the raw string if
    # it is not valid JSON rather than failing the whole request.
    if 'result' in result:
        try:
            result['result'] = json.loads(result['result'])
        except ValueError:
            logging.warning(f"Result field is not valid JSON for task {task_id}")

    return {
        "status": "completed",
        "result_type": result_type,
        "result": result,
    }


@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
    """Return the result file recorded for a completed task.

    Raises:
        HTTPException: 401 for an invalid key; 404 when the task, result key,
            result hash, ``result_file`` field or the file itself is missing;
            400 when the task is not completed or its type has no annotated
            output.
    """
    # verify_api_key raises on failure, so no further check is needed here.
    await verify_api_key(api_key)

    # Task info lives in the main task database.
    task_info = _load_task_info(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    if task_info.get('status') != 'completed':
        raise HTTPException(status_code=400, detail="Task is not completed yet")

    result_type = task_info.get('result_type')
    result_key = task_info.get('result_key')
    if not result_key:
        raise HTTPException(status_code=404, detail="Result key not found")

    if result_type in NON_ANNOTATED_RESULT_TYPES:
        raise HTTPException(status_code=400, detail="Annotated image not available for this task type")

    # Pick the Redis client that holds results for this task type.
    result_redis = RESULT_REDIS_CLIENTS.get(result_type)
    if not result_redis:
        raise HTTPException(status_code=400, detail="Unsupported result type")

    result = result_redis.hgetall(result_key)
    if not result:
        raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
    result = _decode_hash(result)

    result_file = result.get('result_file')
    if not result_file:
        raise HTTPException(status_code=404, detail="Result file not found")

    # NOTE(review): result_file is joined to RESULT_DIR without normalisation
    # and is always served as image/png; confirm both against the workers.
    file_path = os.path.join(RESULT_DIR, result_file)
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="Result image file not found")

    return FileResponse(file_path, media_type="image/png")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8005)