# NOTE(review): removed extraction artifact — a duplicated "453 lines / 16 KiB /
# Python" metadata header that is not valid Python source.
import json
import os
import threading
import uuid
from datetime import datetime, timedelta, timezone

import cv2
import mediapipe as mp
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer, KafkaConsumer
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from redis import Redis
|
|
|
|
app = FastAPI()

# Sub-application carrying all mediapipe endpoints under /mediapipe/*.
mediapipe_app = FastAPI()

app.mount("/mediapipe", mediapipe_app)

# CORS configuration.
# FIX: CORSMiddleware expects a *sequence* of origins. The previous bare string
# was treated as an iterable of characters, so no real origin could ever match.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Configuration
# NOTE(review): broker addresses and the Redis password are hard-coded here;
# consider moving them to environment variables or a secrets store.
MODEL_PATH = "/home/zydi/models/face_landmarker.task"  # Replace with your model path
KAFKA_BROKER = "222.186.10.253:9092"
KAFKA_TOPIC = "mediapipe"
KAFKA_GROUP_ID = "mediapipe_group"

REDIS_HOST = "222.186.10.253"
REDIS_PORT = 6379
REDIS_PASSWORD = "Obscura@2024"
# Logical Redis databases: 10 = task status/results, 12 = API-key records,
# 13 = per-key token usage counters.
REDIS_DB = 10
REDIS_API_DB = 12
REDIS_API_USAGE_DB = 13
UPLOAD_DIR = "/www/wwwroot/beta.obscura.work/upload_files/upload"
RESULT_DIR = "/www/wwwroot/beta.obscura.work/upload_files/result"
# NOTE(review): MAX_FILE_AGE is defined but no cleanup job in this file uses it.
MAX_FILE_AGE = timedelta(hours=1)

# Ensure directories exist
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
|
|
|
|
# Initialize Kafka. Both clients connect eagerly at import time, so importing
# this module fails if the broker is unreachable.
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
# Consumer for queued processing tasks (JSON payloads, see process_task).
# enable_auto_commit means at-most-once handling if a worker dies mid-task.
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    bootstrap_servers=[KAFKA_BROKER],
    group_id=KAFKA_GROUP_ID,
    auto_offset_reset='earliest',
    enable_auto_commit=True,
    value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
|
|
|
|
# Initialize Redis. All three clients target the same server and credentials;
# they differ only in which logical database they select.
def _make_redis(db):
    """Build a Redis client for the shared host/credentials on database *db*."""
    return Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        db=db,
    )


redis_client = _make_redis(REDIS_DB)                      # task status / results
redis_api_client = _make_redis(REDIS_API_DB)              # API-key records
redis_api_usage_client = _make_redis(REDIS_API_USAGE_DB)  # per-key token usage
|
|
|
|
# API-key authentication: clients pass their key in the X-API-Key header.
API_KEY_NAME = "X-API-Key"
# auto_error=False: absence is handled explicitly in get_api_key below.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
|
|
|
async def get_api_key(api_key: str = Header(None, alias=API_KEY_NAME)):
    """Dependency that extracts the API key from the X-API-Key header.

    Raises HTTP 400 when the header is absent; otherwise returns the raw key.
    """
    if api_key is not None:
        return api_key
    raise HTTPException(status_code=400, detail="API密钥缺失")
|
|
|
|
def calculate_tokens(file_path: str, file_type: str) -> int:
    """Estimate the token cost of processing the file at *file_path*.

    Pricing:
      * base: 10 tokens per MB of file size
      * image: plus 5 tokens per 10,000 pixels
      * video: plus (total_pixels / 10,000) * (duration_in_minutes) tokens

    Always returns at least 1; on any error the failure is logged and the
    default cost of 1 is returned.
    """
    try:
        file_size = os.path.getsize(file_path)  # size in bytes

        # Base tokens: 10 tokens per MB of file size.
        base_tokens = int((file_size / (1024 * 1024)) * 10)

        if file_type == "image":
            img = cv2.imread(file_path)
            if img is None:
                raise ValueError("无法读取图片文件")
            height, width = img.shape[:2]
            pixel_count = height * width

            # Image surcharge: 5 tokens per 10,000 pixels.
            base_tokens += int((pixel_count / 10000) * 5)

        elif file_type == "video":
            cap = cv2.VideoCapture(file_path)
            if not cap.isOpened():
                raise ValueError("无法打开视频文件")
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            cap.release()

            # Some containers report 0 FPS; previously this produced a
            # ZeroDivisionError silently swallowed into "1 token". Fail
            # explicitly so the log says why.
            if fps <= 0:
                raise ValueError("无法读取视频帧率")

            pixel_count = width * height * frame_count
            duration = frame_count / fps  # video length in seconds

            # Video surcharge: scales with total pixels and duration (minutes).
            base_tokens += int((pixel_count / 10000) * (duration / 60))

        return max(1, base_tokens)  # never charge less than 1 token
    except Exception as e:
        print(f"计算token时出错: {str(e)}")
        return 1  # default cost on failure
|
|
|
|
|
|
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate *api_key* against Redis and return its merged info record.

    Returns None when the key is unknown, inactive, expired, or carries a
    missing/malformed expiry; otherwise returns a dict merging the key's
    metadata (db 12) with its usage counters (db 13).
    """
    redis_key = f"api_key:{api_key}"

    api_key_info = redis_api_client.hgetall(redis_key)
    if not api_key_info:
        return None

    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}

    if api_key_info.get('is_active') != '1':
        return None

    # A missing or unparsable expires_at previously raised (TypeError /
    # ValueError) and surfaced as a 500; treat it as an invalid key instead.
    expires_raw = api_key_info.get('expires_at')
    if not expires_raw:
        return None
    try:
        expires_at = datetime.fromisoformat(expires_raw)
    except ValueError:
        return None
    if expires_at.tzinfo is None:
        # Naive timestamps cannot be compared with aware ones; assume UTC
        # since all writes in this module use datetime.now(timezone.utc).
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if datetime.now(timezone.utc) > expires_at:
        return None

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    return {
        **api_key_info,
        **usage_info
    }
|
|
|
|
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Record token consumption for *api_key*, both overall and per model.

    All writes go through one Redis pipeline so the counters and timestamps
    are sent in a single round trip.
    """
    usage_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()

    pipe = redis_api_usage_client.pipeline()

    # Aggregate counters across all models.
    pipe.hincrby(usage_key, "tokens_used", new_tokens_used)
    pipe.hset(usage_key, "last_used_at", now_iso)

    # Per-model counters (fields are namespaced by the model name).
    pipe.hincrby(usage_key, f"{model_name}_tokens_used", new_tokens_used)
    pipe.hset(usage_key, f"{model_name}_last_used_at", now_iso)

    pipe.execute()
|
|
|
|
class mediapipeEmbedder:
    """Wraps MediaPipe's FaceLandmarker to produce face embeddings and annotated images."""

    def __init__(self, model_path):
        """Load the face-landmarker model from *model_path* (single-face mode)."""
        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1
        )
        self.detector = vision.FaceLandmarker.create_from_options(options)

    def get_mediapipe_landmarks(self, image):
        """Return an (N, 3) array of normalized landmarks for the first detected face, or None."""
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
        detection_result = self.detector.detect(mp_image)
        if detection_result.face_landmarks:
            return np.array([(lm.x, lm.y, lm.z) for lm in detection_result.face_landmarks[0]])
        return None

    def process_image(self, image_data):
        """Decode encoded image bytes, detect landmarks, and annotate the image.

        Returns (results_dict, annotated_image) on success, (None, image)
        when no face is found, and (None, None) when the bytes cannot be
        decoded at all.
        """
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # FIX: undecodable input previously crashed inside mp.Image;
            # fail soft so callers see the normal "no result" path.
            return None, None

        landmarks = self.get_mediapipe_landmarks(img)
        if landmarks is None:
            return None, img

        # Calculate a more detailed mediapipe embedding.
        embedding = self.calculate_detailed_embedding(landmarks)

        # Draw each landmark as a small green dot (coordinates are normalized,
        # so scale by the image dimensions).
        for lm in landmarks:
            cv2.circle(img, (int(lm[0] * img.shape[1]), int(lm[1] * img.shape[0])), 2, (0, 255, 0), -1)

        return {
            "embedding": embedding,
            "landmarks": landmarks.tolist()
        }, img

    def calculate_detailed_embedding(self, landmarks):
        """Build a 21-dim embedding from an (N, 3) array of normalized landmarks.

        Layout: mean(3) + std(3) + median(3) + min(3) + max(3) followed by
        [eye distance, mouth width, nose-to-mouth, face width, height, depth].
        Landmark indices follow the MediaPipe 478-point face mesh — TODO
        confirm the model emits that topology.
        """
        mean = np.mean(landmarks, axis=0)
        std = np.std(landmarks, axis=0)
        median = np.median(landmarks, axis=0)
        min_vals = np.min(landmarks, axis=0)
        max_vals = np.max(landmarks, axis=0)

        # Key facial points (face-mesh indices).
        nose_tip = landmarks[4]
        left_eye = landmarks[159]
        right_eye = landmarks[386]
        left_mouth = landmarks[61]
        right_mouth = landmarks[291]

        eye_distance = np.linalg.norm(left_eye - right_eye)
        mouth_width = np.linalg.norm(left_mouth - right_mouth)
        nose_to_mouth = np.linalg.norm(nose_tip - (left_mouth + right_mouth) / 2)

        # Bounding-box extents of the whole face.
        face_width = np.max(landmarks[:, 0]) - np.min(landmarks[:, 0])
        face_height = np.max(landmarks[:, 1]) - np.min(landmarks[:, 1])
        face_depth = np.max(landmarks[:, 2]) - np.min(landmarks[:, 2])

        # Combine all features into a single flat embedding.
        embedding = np.concatenate([
            mean, std, median, min_vals, max_vals,
            [eye_distance, mouth_width, nose_to_mouth, face_width, face_height, face_depth]
        ])

        return embedding.tolist()
|
|
|
|
embedder = mediapipeEmbedder(MODEL_PATH)
|
|
|
|
def process_image(image_data, filename):
    """Run landmark detection on raw image bytes and persist the annotated copy.

    Returns (results, annotated_filename) on success and (None, None) when no
    face is detected or on error. FIX: restores the f-string placeholders that
    were garbled to a literal "(unknown)" — *filename* was otherwise unused.
    """
    try:
        results, annotated_img = embedder.process_image(image_data)

        if results:
            # Save the annotated image under a mediapipe_-prefixed name so it
            # can be located from the original (uuid-based) filename.
            annotated_filename = f"mediapipe_{filename}"
            annotated_path = os.path.join(RESULT_DIR, annotated_filename)
            cv2.imwrite(annotated_path, annotated_img)

            return results, annotated_filename
        else:
            print(f"No face landmarks detected in image: {filename}")
            return None, None
    except Exception as e:
        print(f"Error processing image {filename}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
|
|
|
|
def process_video(video_data, filename):
    """Sample a video at ~1 frame/second, detect landmarks, and write an
    annotated copy to RESULT_DIR.

    Returns (per-frame results list, annotated filename) or (None, None) on
    error. FIX: restores f-string placeholders garbled to "(unknown)", guards
    against a reported FPS of 0 (previously a modulo-by-zero), and removes the
    temp file even when processing fails.
    """
    temp_video_path = os.path.join(UPLOAD_DIR, f"mediapipe_{filename}")
    try:
        with open(temp_video_path, 'wb') as temp_video:
            temp_video.write(video_data)

        cap = cv2.VideoCapture(temp_video_path)
        frame_count = 0
        results = []

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Some containers report 0 FPS; fall back to analyzing every frame
        # rather than crashing on the modulo below.
        sample_interval = fps if fps > 0 else 1

        annotated_filename = f"mediapipe_{filename}"
        output_path = os.path.join(RESULT_DIR, annotated_filename)
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_interval == 0:
                # Analyze roughly one frame per second of video.
                frame_results, annotated_frame = embedder.process_image(cv2.imencode('.jpg', frame)[1].tobytes())
                if frame_results:
                    results.append({"frame": frame_count, "results": frame_results})
            else:
                annotated_frame = frame

            out.write(annotated_frame)
            frame_count += 1

        cap.release()
        out.release()

        return results, annotated_filename
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Always remove the temp copy, even when processing fails.
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
|
|
|
|
@mediapipe_app.post("/upload")
async def upload_file(file: UploadFile = File(...), api_key: str = Depends(get_api_key)):
    """Accept an image/video upload, charge tokens, and queue it for processing.

    Flow: validate key -> persist upload under a uuid name -> price the job ->
    check balance -> charge -> enqueue on Kafka -> mark "queued" in Redis ->
    report post-charge usage to the caller.
    """
    # Validate the API key (403 when unknown/inactive/expired).
    api_key_info = await verify_api_key(api_key)
    if not api_key_info:
        raise HTTPException(status_code=403, detail="无效的API密钥")

    content = await file.read()
    file_extension = os.path.splitext(file.filename)[1].lower()
    new_filename = f"{uuid.uuid4()}{file_extension}"

    original_file_path = os.path.join(UPLOAD_DIR, new_filename)
    with open(original_file_path, "wb") as f:
        f.write(content)

    # Price the job by file type and size.
    file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
    tokens_required = calculate_tokens(original_file_path, file_type)
    if tokens_required is None or tokens_required <= 0:
        raise HTTPException(status_code=500, detail="无法计算所需的token数量")

    # Check the caller's remaining balance before charging.
    # NOTE(review): read-then-write is not atomic; concurrent uploads on the
    # same key could overdraw slightly.
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Charge the tokens.
    model_name = "mediapipe"
    await update_token_usage(api_key, tokens_required, model_name)

    # Queue the task for the background worker (see process_task); reuse the
    # file_type computed above instead of re-deriving it.
    producer.send(KAFKA_TOPIC, json.dumps({
        "filename": new_filename,
        "file_type": file_type
    }).encode('utf-8'))

    redis_key = f"mediapipe_result:{new_filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    # Re-read usage so the response reflects the post-charge balance.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return JSONResponse(content={
        "message": "文件已上传并排队等待处理",
        "filename": new_filename,
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    })
|
|
|
|
|
|
@mediapipe_app.get("/result/{filename}", dependencies=[Depends(verify_api_key)])
async def get_mediapipe_result(filename: str):
    """Return the processing status/result JSON for an uploaded file.

    FIX: restores the route path parameter and Redis-key f-string placeholder
    that were garbled to a literal "(unknown)".
    """
    redis_key = f"mediapipe_result:{filename}"
    result = redis_client.get(redis_key)
    if result:
        result_data = json.loads(result)
        return JSONResponse(content=result_data)
    else:
        raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
@mediapipe_app.get("/annotated/{filename}", dependencies=[Depends(verify_api_key)])
async def get_annotated_file(filename: str):
    """Stream the annotated image/video produced for *filename*.

    FIX: restores the garbled route/Redis placeholders, and serves .jpg as
    image/jpeg — "image/jpg" is not a registered MIME type.
    """
    redis_key = f"mediapipe_result:{filename}"
    result = redis_client.get(redis_key)
    if result:
        result_data = json.loads(result)
        if result_data["status"] == "completed":
            annotated_filename = result_data["annotated_filename"]
            file_path = os.path.join(RESULT_DIR, annotated_filename)
            if os.path.exists(file_path):
                def iterfile():
                    # Stream the file in chunks rather than loading it whole.
                    with open(file_path, mode="rb") as file_like:
                        yield from file_like

                file_extension = os.path.splitext(annotated_filename)[1].lower()
                if file_extension in ['.jpg', '.jpeg']:
                    media_type = "image/jpeg"
                elif file_extension == '.png':
                    media_type = "image/png"
                else:
                    # Annotated videos are always written as mp4 (see process_video).
                    media_type = "video/mp4"
                return StreamingResponse(iterfile(), media_type=media_type)

    raise HTTPException(status_code=404, detail="Annotated file not found")
|
|
|
|
def process_task():
    """Kafka worker loop: consume queued tasks and run image/video processing.

    Updates the per-file Redis status record through the lifecycle
    queued -> processing -> completed/failed. FIX: restores the Redis-key
    f-string placeholder garbled to "(unknown)" and merges the duplicated
    file-read code.
    """
    for message in consumer:
        task = message.value
        filename = task['filename']
        file_type = task['file_type']

        file_path = os.path.join(UPLOAD_DIR, filename)

        redis_key = f"mediapipe_result:{filename}"
        redis_client.set(redis_key, json.dumps({"status": "processing"}))

        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if file_type == "image":
                results, annotated_filename = process_image(content, filename)
            else:
                results, annotated_filename = process_video(content, filename)

            if results and annotated_filename:
                redis_client.set(redis_key, json.dumps({
                    "results": results,
                    "status": "completed",
                    "annotated_filename": annotated_filename
                }))
            else:
                redis_client.set(redis_key, json.dumps({"status": "failed"}))
        except Exception as e:
            print(f"Error processing task: {str(e)}")
            redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(e)}))
|
|
|
|
def listen_redis_changes():
    """Log status transitions by watching Redis keyspace notifications.

    Requires the server to have keyspace events enabled
    (e.g. notify-keyspace-events "K$").
    """
    pubsub = redis_client.pubsub()
    # FIX: subscribe on the database this service actually writes to
    # (REDIS_DB = 10); the previous hard-coded "@3" watched the wrong db.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:mediapipe_result:*')

    for message in pubsub.listen():
        if message['type'] == 'pmessage':
            # Channel is "__keyspace@<db>__:mediapipe_result:<name>"; the last
            # colon-separated segment is the filename (uuid names contain no ':').
            key = message['channel'].decode('utf-8').split(':')[-1]
            operation = message['data'].decode('utf-8')

            if operation == 'set':
                value = redis_client.get(f"mediapipe_result:{key}")
                if value:
                    result = json.loads(value)
                    print(f"Status update for {key}: {result['status']}")
|
|
|
|
if __name__ == "__main__":
    # Start the background workers: the Kafka task consumer and the Redis
    # keyspace listener. Daemon threads so they die with the main process.
    threading.Thread(target=process_task, daemon=True).start()
    threading.Thread(target=listen_redis_changes, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=7006)