This commit is contained in:
2025-01-12 06:58:52 +00:00
parent a3dcc7a619
commit b22b949620
15 changed files with 1182 additions and 97 deletions
+1
View File
@@ -18,3 +18,4 @@ api_chat/runtime/*
api_chat/docs/* api_chat/docs/*
!api_chat/docs/.gitkeep !api_chat/docs/.gitkeep
+140
View File
@@ -0,0 +1,140 @@
# API集合
该项目为所有API集合,集成了视觉分析、聊天对话和语音处理等功能。
## 项目结构
```
API/
├── api/ # 视觉分析和处理模块
│ ├── producer/ # 主程序入口,生产者,分配任务
│ ├── cpm_analyze.py # CPM_OCR模型分析
│ ├── qwenvl_analyze.py # QwenVL_OCR模型分析
│ ├── cpm_scene.py # CPM_场景模型分析
│ ├── qwenvl_scene.py # QwenVL_场景模型分析
│ ├── compare.py # 人脸对比模型
│ ├── yolo.py # YOLO目标检测
│ ├── face.py # 人脸检测
│ ├── fall.py # 跌倒检测
│ ├── pose.py # 姿态估计
│ └── media.py # 媒体处理
├── api_chat/ # 聊天和语音处理模块
│ ├── producer_chat/ # 聊天生产者
│ ├── chat.py # 聊天功能
│ ├── tts.py # 文字转语音
│ ├── asr.py # 语音识别
│ ├── GPT_SoVITS/ # GPT_SoVITS模型集成,
│ ├── sample/ # OpenBMB模型——学习音色,音色+文本内容,
│ ├── tools/ # GPT_SoVITS模型——工具函数
│ ├── runtime/ # GPT_SoVITS模型——运行时函数
│ ├── docs/ # GPT_SoVITS模型——文档
│ ├── TEMP/ # OpenBMB模型临时文件夹,
│ └── before/ # 历史代码,可以忽略
├── api_history/ # api历史代码,可以忽略
├── chat_history/ # api_chat历史代码,可以忽略
└── api_old/ # api历史代码,可以忽略
```
## 主要功能
### 视觉分析模块 (api/)
- 目标检测和跟踪
- 人脸识别
- 人脸对比
- 姿态估计
- 跌倒检测
- 场景理解(基于CPM和QwenVL模型)
### 聊天对话模块 (api_chat/)
- 文本对话功能
- 语音识别 (ASR): 通过Whisper模型
- 文字转语音 (TTS): 通过GPT_SoVITS模型
- 多模型支持(通过Ollama
## 使用说明
### API 部分 http://dev.obscura.work/v1
1. producer 目录 # 生产者,分配任务
2. 服务器:222.186.10.253:8005
3. kafka 配置:222.186.10.253:9092
topic分配:
- yolo: "yolo"
- pose: "pose"
- qwenvl: "qwenvl"
- qwenvl_analyze: "qwenvl_analyze"
- cpm: "cpm"
- cpm_analyze: "cpm_analyze"
- fall: "fall"
- face: "face"
- mediapipe: "mediapipe"
- compare: "compare"
4. redis 配置:150.158.144.159:13003
db分配:
- 4: "yolo"
- 5: "pose"
- 9: "qwenvl"
- 32: "qwenvl_analyze"
- 8: "cpm"
- 31: "cpm_analyze"
- 6: "fall"
- 7: "face"
- 10: "mediapipe"
- 30: "compare"
5. 模型配置:
- YOLO = "/obscura/models/yolov8x.pt"
- POSE = "/obscura/models/yolov8x-pose.pt"
- QWEN = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
- FALL = "/obscura/models/yolov8n-fall.pt"
- FACE = "/obscura/models/yolov8n-face.pt"
- MEDIAPIPE = "/obscura/models/face_landmarker.task"
- CPM(ollama) = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate"
6. 上传文件及结果保存目录:
- UPLOAD_DIR = "/obscura/task/upload"
- RESULT_DIR = "/obscura/task/result"
### API_Chat 部分 http://dev.obscura.work/v1_chat
1. producer_chat 目录 # 聊天生产者
2. 服务器:222.186.10.253:8008
3. kafka 配置:222.186.10.253:9092
topic分配:
- asr: "asr"
- chat: "chat"
- tts: "tts"
4. redis 配置:150.158.144.159:13003
db分配:
- 2: "api key"
- 3: "api使用情况"
- 11: "task 任务记录"
- 12: "asr 记录"
- 13: "chat 记录"
- 14: "tts 记录"
#语言
- 63: "session_zh 中文"
- 62: "session_en 英文"
- 61: "session_ko 韩语"
#音色
- 15: "girl"
- 16: "woman"
- 17: "man"
- 18: "leijun"
- 19: "dufu"
- 20: "hejiong"
- 21: "mahuateng"
- 22: "lidan"
- 23: "dabing"
- 24: "luoxiang"
- 25: "xuzhiyuan"
- 26: "yuhua"
- 27: "liuzhenyun"
5. 音频文件保存目录:
- OUTPUT_PATH=/obscura/task/audio_files # 音频文件保存目录
## 注意事项
- 注意redis的db分配
- 注意kafka的topic分配
- 注意producer的config.py环境配置
- 注意producer_chat的.env环境配置
- 确保模型权重文件已正确配置
- 检查API密钥和环境变量设置
- 注意资源使用和性能优化
-97
View File
@@ -1,97 +0,0 @@
# 将本地Ollama API完全反向代理
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import httpx
OLLAMA_URL = "http://127.0.0.1:11434"
app = FastAPI()
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
# 创建异步HTTP客户端
async_client = httpx.AsyncClient()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
async def proxy_to_ollama(request: Request, path: str):
if request.method == "OPTIONS":
# 处理预检请求
headers = {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "*",
}
return StreamingResponse(content=iter([]), headers=headers)
target_url = f"{OLLAMA_URL}/{path}"
# 获取请求体
body = await request.body()
# 获取请求头
headers = dict(request.headers)
headers.pop("host", None)
try:
# 将请求转换为Python请求
python_request = {
"method": request.method,
"url": target_url,
"headers": headers,
"data": body
}
# 使用Python请求发送到Ollama API
async with async_client.stream(**python_request) as response:
# 返回Ollama API的流式响应,并添加CORS头
response_headers = dict(response.headers)
response_headers["Access-Control-Allow-Origin"] = "*"
return StreamingResponse(
response.aiter_raw(),
status_code=response.status_code,
headers=response_headers
)
except httpx.RequestError as e:
return {"error": f"请求Ollama API时发生错误: {str(e)}"}, 500
except httpx.StreamClosed:
# 处理流关闭异常
print("流已关闭,客户端可能已断开连接")
return {"error": "流已关闭,客户端可能已断开连接"}, 499
@app.on_event("shutdown")
async def shutdown_event():
await async_client.aclose()
if __name__ == "__main__":
import uvicorn
import requests
import json
# 测试Ollama API
test_url = "http://localhost:11434/api/generate"
test_data = {
"model": "llama3.1",
"prompt": "Why is the sky blue?",
"stream": False
}
try:
response = requests.post(test_url, json=test_data)
if response.status_code == 200:
print("Ollama API 测试成功:")
print(json.dumps(response.json(), indent=2))
else:
print(f"Ollama API 测试失败,状态码: {response.status_code}")
print(response.text)
except requests.RequestException as e:
print(f"Ollama API 测试出错: {str(e)}")
# 启动FastAPI应用
uvicorn.run(app, host="0.0.0.0", port=8001)
+80
View File
@@ -0,0 +1,80 @@
# config.py
import os
# Kafka配置
KAFKA_BROKER = "222.186.10.253:9092"
KAFKA_GROUP_ID_PREFIX = "group"
# Redis配置
REDIS_HOST = "150.158.144.159"
REDIS_PORT = 13003
REDIS_PASSWORD = "Obscura@2024"
MAIN_REDIS_DB = 0
REDIS_API_DB = 2
REDIS_API_USAGE_DB = 3
# 目录配置
UPLOAD_DIR = "/obscura/task/upload"
RESULT_DIR = "/obscura/task/result"
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 模型配置
YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt"
POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt"
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt"
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
# Ollama配置
OLLAMA_URL = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate"
# 各个worker的配置
WORKER_CONFIGS = {
"yolo": {
"kafka_topic": "yolo",
"redis_db": 4,
},
"pose": {
"kafka_topic": "pose",
"redis_db": 5,
},
"qwenvl": {
"kafka_topic": "qwenvl",
"redis_db": 9,
},
"qwenvl_analyze": {
"kafka_topic": "qwenvl_analyze",
"redis_db": 32,
},
"cpm": {
"kafka_topic": "cpm",
"redis_db": 8,
},
"cpm_analyze": {
"kafka_topic": "cpm_analyze",
"redis_db": 31,
},
"fall": {
"kafka_topic": "fall",
"redis_db": 6,
},
"face": {
"kafka_topic": "face",
"redis_db": 7,
},
"mediapipe": {
"kafka_topic": "mediapipe",
"redis_db": 10,
},
"compare": {
"kafka_topic": "compare",
"redis_db": 30,
}
}
# GPU设置
CUDA_DEVICE_0 = "cuda:0"
CUDA_DEVICE_1 = "cuda:1"
+442
View File
@@ -0,0 +1,442 @@
# main.py
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timedelta, timezone
import string
from decord import VideoReader
from PIL import Image
from fastapi.responses import FileResponse
import logging
from config import *
app = FastAPI()
v1_app = FastAPI()
app.mount("/v1", v1_app)
# CORS设置
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = KAFKA_BROKER
REDIS_HOST = REDIS_HOST
REDIS_PORT = REDIS_PORT
REDIS_PASSWORD = REDIS_PASSWORD
REDIS_DB = MAIN_REDIS_DB
REDIS_API_DB = REDIS_API_DB
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
UPLOAD_DIR = UPLOAD_DIR
RESULT_DIR = RESULT_DIR
MAX_FILE_AGE = timedelta(hours=1)
# 确保目录存在
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
# 定义支持的任务类型
KAFKA_TOPICS = {
'pose': 'pose',
'mediapipe': 'mediapipe',
'qwenvl': 'qwenvl',
'yolo': 'yolo',
'fall': 'fall',
'face': 'face',
'cpm': 'cpm',
'compare': 'compare',
'qwenvl_analyze': 'qwenvl_analyze',
'cpm_analyze': 'cpm_analyze'
}
TASK_TYPES = list(KAFKA_TOPICS.keys())
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_DB
)
redis_api_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_DB
)
redis_api_usage_client = Redis(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
db=REDIS_API_USAGE_DB
)
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db'])
redis_qwenvl_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl_analyze']['redis_db'])
redis_cpm_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm_analyze']['redis_db'])
@v1_app.get('/favicon.ico', include_in_schema=False)
async def favicon():
file_name = "favicon.ico"
file_path = os.path.join(app.root_path, "static", file_name)
if os.path.isfile(file_path):
return FileResponse(file_path)
else:
return {"message": "Favicon not found"}, 404
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
# 定义 base62 字符集
BASE62 = string.digits + string.ascii_letters
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
logging.info(f"验证API密钥: {api_key}")
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
logging.warning(f"API密钥不存在: {api_key}")
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
logging.warning(f"API密钥已停用: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
logging.warning(f"API密钥已过期: {api_key}")
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
logging.info(f"API密钥验证成功: {api_key}")
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
# 更新总的token使用量
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
# 更新特定模型的token使用量
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
def calculate_tokens(file_path: str, file_type: str) -> int:
base_tokens = 0
try:
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
# 基础token:每MB文件大小消耗10个token
base_tokens = int((file_size / (1024 * 1024)) * 0.1)
if file_type == "image":
img = Image.open(file_path)
width, height = img.size
pixel_count = width * height
# 图片token:每100个像素额外消耗5个token
image_tokens = int((pixel_count / 100000000) * 0.1)
base_tokens += image_tokens
elif file_type == "video":
vr = VideoReader(file_path)
fps = vr.get_avg_fps()
frame_count = len(vr)
width, height = vr[0].shape[1], vr[0].shape[0]
pixel_count = width * height * frame_count
duration = frame_count / fps # 视频时长(秒)
# 视频token:每100万像素每秒额外消耗1个token
video_tokens = int((pixel_count / 100000000) * (duration / 60))
base_tokens += video_tokens
return max(1, base_tokens) # 确保至少返回1个token
except Exception as e:
print(f"计算token时出错: {str(e)}")
return 1 # 出错时返回默认值1
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
if task_type not in KAFKA_TOPICS:
raise HTTPException(status_code=400, detail="不支持的任务类型")
content = await file.read()
file_extension = os.path.splitext(file.filename)[1].lower()
new_filename = f"{uuid.uuid4()}{file_extension}"
file_path = os.path.join(UPLOAD_DIR, new_filename)
with open(file_path, "wb") as f:
f.write(content)
# 计算 token
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
tokens_required = calculate_tokens(file_path, file_type)
# 检查并更新 token 使用量
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, task_type)
# 创建任务记录
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"filename": new_filename,
"file_type": file_type,
"task_type": task_type,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 存储任务信息到Redis
redis_client.set(f"task:{task_id}", json.dumps(task_data))
logging.info(f"任务信息已存储到Redis: {task_id}")
# 发送任务到对应的Kafka主题
kafka_topic = KAFKA_TOPICS[task_type]
producer.send(kafka_topic, task_data)
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
response_data = {
"message": "文件已上传并排队等待处理",
"task_id": task_id,
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{task_type}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
logging.info(f"上传文件完成: {task_id}")
return JSONResponse(content=response_data)
# 为每个任务类型创建单独的端点
@v1_app.post("/pose")
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
logging.info(f"收到 /pose端点的请求")
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
@v1_app.post("/cpm")
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
@v1_app.post("/qwenvl")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
@v1_app.post("/qwenvl_analyze")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="qwenvl_analyze", api_key_info=api_key_info)
@v1_app.post("/cpm_analyze")
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="cpm_analyze", api_key_info=api_key_info)
@v1_app.post("/yolo")
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
@v1_app.post("/fall")
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
@v1_app.post("/face")
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="face", api_key_info=api_key_info)
@v1_app.post("/mediapipe")
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
@v1_app.post("/compare")
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
return await upload_file(file, task_type="compare", api_key_info=api_key_info)
@v1_app.get("/result/{task_id}")
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务状态
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
return {"status": task_info['status'], "message": "Task is not completed yet"}
result_type = task_info['result_type']
result_key = task_info['result_key']
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'cpm': redis_cpm_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client,
'qwenvl': redis_qwenvl_client,
'compare': redis_compare_client,
'qwenvl_analyze': redis_qwenvl_analyze_client,
'cpm_analyze': redis_cpm_analyze_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
# 将 result 字段解析为 JSON(如果存在)
if 'result' in result:
result['result'] = json.loads(result['result'])
return {
"status": "completed",
"result_type": result_type,
"result": result
}
@v1_app.get("/annotated/{task_id}")
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
api_key_info = await verify_api_key(api_key)
if not api_key_info:
raise HTTPException(status_code=403, detail="无效的API密钥")
# 从 REDIS_DB (15) 获取任务信息
task_info = redis_client.hgetall(f"task:{task_id}")
if not task_info:
raise HTTPException(status_code=404, detail="Task not found")
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
if task_info['status'] != 'completed':
raise HTTPException(status_code=400, detail="Task is not completed yet")
result_type = task_info.get('result_type')
result_key = task_info.get('result_key')
if not result_key:
raise HTTPException(status_code=404, detail="Result key not found")
if result_type in ['cpm', 'qwenvl','cpm_analyze', 'qwenvl_analyze']:
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
# 根据任务类型选择相应的 Redis 客户端
redis_client_map = {
'pose': redis_pose_client,
'yolo': redis_yolo_client,
'face': redis_face_client,
'fall': redis_fall_client,
'mediapipe': redis_mediapipe_client,
'compare': redis_compare_client
}
result_redis = redis_client_map.get(result_type)
if not result_redis:
raise HTTPException(status_code=400, detail="Unsupported result type")
result = result_redis.hgetall(result_key)
if not result:
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
result_file = result.get('result_file')
if not result_file:
raise HTTPException(status_code=404, detail="Result file not found")
file_path = os.path.join(RESULT_DIR, result_file)
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="Result image file not found")
return FileResponse(file_path, media_type="image/png")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8005)
+12
View File
@@ -0,0 +1,12 @@
fastapi
uvicorn
python-multipart
kafka-python
redis
pillow
decord
pydantic
requests
python-jose[cryptography]
passlib[bcrypt]
sqlalchemy
+94
View File
@@ -0,0 +1,94 @@
# Kafka 配置
KAFKA_BROKER=222.186.10.253:9092
KAFKA_ASR_TOPIC=asr
KAFKA_CHAT_TOPIC=chat
KAFKA_TTS_TOPIC=tts
# Redis 配置
REDIS_HOST=150.158.144.159
REDIS_PORT=13003
REDIS_ASR_DB=12
REDIS_CHAT_DB=13
REDIS_TTS_DB=14
REDIS_PASSWORD=Obscura@2024
REDIS_API_DB=2
REDIS_API_USAGE_DB=3
REDIS_TASK_DB=11
REDIS_SESSION_DB=63
REDIS_SESSION_DB_ZH=63
REDIS_SESSION_DB_EN=62
REDIS_SESSION_DB_KO=61
# CORS 配置
# ALLOWED_ORIGINS=https://beta.obscura.work
# GPT-SoVITS 配置
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
REF_AUDIO_PATH=sample/woman.wav
REF_TEXT_PATH=sample/woman.txt
REF_LANGUAGE=中文
TARGET_LANGUAGE=多语种混合
OUTPUT_PATH=/obscura/task/audio_files
# VOICE_CONFIGS
GIRL_REF_AUDIO=sample/gril.wav
GIRL_REF_TEXT=sample/gril.txt
WOMAN_REF_AUDIO=sample/woman.wav
WOMAN_REF_TEXT=sample/woman.txt
MAN_REF_AUDIO=sample/man.wav
MAN_REF_TEXT=sample/man.txt
LEIJUN_REF_AUDIO=sample/leijun.wav
LEIJUN_REF_TEXT=sample/leijun.txt
DUFU_REF_AUDIO=sample/dufu.wav
DUFU_REF_TEXT=sample/dufu.txt
HEJIONG_REF_AUDIO=sample/hejiong.wav
HEJIONG_REF_TEXT=sample/hejiong.txt
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
MAHUATENG_REF_TEXT=sample/mahuateng.txt
LIDAN_REF_AUDIO=sample/lidan.wav
LIDAN_REF_TEXT=sample/lidan.txt
YUHUA_REF_AUDIO=sample/yuhua.wav
YUHUA_REF_TEXT=sample/yuhua.txt
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
DABING_REF_AUDIO=sample/dabing.wav
DABING_REF_TEXT=sample/dabing.txt
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
LUOXIANG_REF_TEXT=sample/luoxiang.txt
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
REDIS_GIRL_DB = 15
REDIS_WOMAN_DB = 16
REDIS_MAN_DB = 17
REDIS_LEIJUN_DB = 18
REDIS_DUFU_DB = 19
REDIS_HEJIONG_DB = 20
REDIS_MAHUATENG_DB = 21
REDIS_LIDAN_DB = 22
REDIS_DABING_DB = 23
REDIS_LUOXIANG_DB = 24
REDIS_XUZHIYUAN_DB = 25
REDIS_YUHUA_DB = 26
REDIS_LIUZHENYUN_DB = 27
+406
View File
@@ -0,0 +1,406 @@
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from fastapi.responses import FileResponse
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import tempfile
import hashlib
from pydantic import BaseModel, Field
# 在文件顶部添加这个函数
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 加载 .env 文件
load_dotenv()
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# Redis 配置
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
OUTPUT_PATH= os.getenv('OUTPUT_PATH')
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
# 创建一个音色到对应 Redis 客户端的映射
voice_to_redis = {
"default": redis_tts_girl,
"girl": redis_tts_girl,
"woman": redis_tts_woman,
"man": redis_tts_man,
"leijun": redis_tts_leijun,
"dufu": redis_tts_dufu,
"hejiong": redis_tts_hejiong,
"mahuateng": redis_tts_mahuateng,
"lidan": redis_tts_lidan,
"yuhua": redis_tts_yuhua,
"liuzhenyun": redis_tts_liuzhenyun,
"dabing": redis_tts_dabing,
"luoxiang": redis_tts_luoxiang,
"xuzhiyuan": redis_tts_xuzhiyuan
}
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def get_audio_hash(text):
return hashlib.md5(text.encode()).hexdigest()
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
if api_key and api_key.startswith("Bearer "):
key = api_key.split(" ")[1]
if key.startswith("obs-"):
return key
raise HTTPException(
status_code=401,
detail="无效的API密钥",
headers={"WWW-Authenticate": "Bearer"},
)
async def verify_api_key(api_key: str = Depends(get_api_key)):
redis_key = f"api_key:{api_key}"
api_key_info = redis_api_client.hgetall(redis_key)
if not api_key_info:
raise HTTPException(status_code=401, detail="无效的API密钥")
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
if api_key_info.get('is_active') != '1':
raise HTTPException(status_code=401, detail="API密钥已停用")
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(status_code=401, detail="API密钥已过期")
usage_info = redis_api_usage_client.hgetall(redis_key)
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
return {
"api_key": api_key,
**api_key_info,
**usage_info
}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
redis_key = f"api_key:{api_key}"
current_time = datetime.now(timezone.utc).isoformat()
pipe = redis_api_usage_client.pipeline()
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
pipe.hset(redis_key, "last_used_at", current_time)
model_tokens_field = f"{model_name}_tokens_used"
model_last_used_field = f"{model_name}_last_used_at"
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
pipe.hset(redis_key, model_last_used_field, current_time)
pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
api_key = api_key_info['api_key']
usage_key = f"api_key:{api_key}"
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
if tokens_used + tokens_required > total_tokens:
raise HTTPException(status_code=403, detail="Token 余额不足")
# 更新 token 使用量
await update_token_usage(api_key, tokens_required, model_name)
# 发送任务到Kafka
producer.send(kafka_topic, task_data)
# 获取更新后的 token 使用情况
updated_api_key_info = await verify_api_key(api_key)
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
return {
"message": f"{model_name.upper()}请求已排队等待处理",
"tokens_used": tokens_required,
"total_tokens_used": new_tokens_used,
f"{model_name}_tokens_used": model_tokens_used,
"tokens_remaining": total_tokens - new_tokens_used
}
class TTSRequest(BaseModel):
text: str
voice: str = Field(..., description="选择的音色")
class ChatRequest(BaseModel):
session_id: str
query: str
model: str = "qwen2.5:3b"
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
text_hash = get_audio_hash(request.text)
# 验证音色选择
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
if request.voice not in valid_voices:
raise HTTPException(status_code=400, detail="无效的音色选择")
# 如果声音是 'default',则将其视为 'girl'
voice = 'girl' if request.voice == 'default' else request.voice
# 使用对应音色的 Redis 客户端
redis_tts = voice_to_redis[request.voice]
# 检查是否已存在相同内容的音频文件
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
if existing_audio_info:
existing_audio_path = json.loads(existing_audio_info)['path']
if os.path.exists(existing_audio_path):
return {
"message": "TTS请求已完成",
"task_id": task_id,
"status": "completed",
"audio_path": existing_audio_path
}
# 如果不存在,创建新的任务
task_data = {
"task_id": task_id,
"text": request.text,
"text_hash": text_hash,
"voice": request.voice,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 存储任务信息到Redis
redis_task_client.set(f"task_status:tts:{task_id}", "queued")
result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
UPLOAD_DIR = "/obscura/task/audio_upload"
os.makedirs(UPLOAD_DIR, exist_ok=True)
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
with open(file_path, "wb") as temp_audio:
content = await audio.read()
temp_audio.write(content)
task_data = {
'file_path': file_path,
'task_id': task_id,
'status': 'queued'
}
# 存储任务状态,使用一致的键名格式
redis_task_client.set(f"task_status:asr:{task_id}", "queued")
result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"session_id": request.session_id,
"query": request.query,
"model": request.model,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 设置任务状态为 "queued"
redis_task_client.set(f"chat:{task_id}:status", "queued")
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态
task_status = redis_task_client.get(f"chat:{task_id}:status")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis任务数据库获取聊天结果
chat_result = redis_task_client.get(f"chat:{task_id}:result")
if chat_result:
result = json.loads(chat_result)
return {
"status": "completed",
"result": result
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
if task_info:
task_data = json.loads(task_info)
text_hash = task_data['text_hash']
voice = task_data['voice']
# 'default' 和 'girl' 都使用 girl 的 Redis
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
audio_info = redis_tts.get(f"tts:{text_hash}")
if audio_info:
audio_path = json.loads(audio_info)['path']
return {
"status": "completed",
"audio_path": audio_path
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
# 从Redis任务数据库获取任务状态,使用一致的键名格式
task_status = redis_task_client.get(f"task_status:asr:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从Redis ASR结果数据库获取转录结果
transcription = redis_asr_client.get(f"asr:{task_id}")
if transcription:
return {
"status": "completed",
"transcription": transcription.decode('utf-8')
}
return {"status": status}
return {"status": "not_found"}
@v1_chat_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
if task_status:
status = task_status.decode('utf-8')
if status == "completed":
# 从任务信息中获取使用的音色
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
if task_info:
task_data = json.loads(task_info)
voice = task_data.get('voice', 'girl') # 默认使用 'girl'
# 'default' 和 'girl' 都使用 girl 的 Redis
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
# 从对应音色的 Redis 获取音频文件路径
audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
if audio_info:
audio_path = json.loads(audio_info)['path']
if os.path.exists(audio_path):
file_name = os.path.basename(audio_path)
return FileResponse(audio_path, media_type="audio/wav", filename=file_name)
else:
raise HTTPException(status_code=404, detail="音频文件不存在")
elif status == "queued" or status == "processing":
raise HTTPException(status_code=202, detail="音频文件正在生成中")
else:
raise HTTPException(status_code=500, detail="任务处理失败")
raise HTTPException(status_code=404, detail="任务不存在")
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
return {"available_voices": valid_voices}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8008)
+7
View File
@@ -0,0 +1,7 @@
fastapi
uvicorn
pydantic
kafka-python
redis
python-dotenv
python-multipart
View File