update
This commit is contained in:
@@ -18,3 +18,4 @@ api_chat/runtime/*
|
|||||||
|
|
||||||
api_chat/docs/*
|
api_chat/docs/*
|
||||||
!api_chat/docs/.gitkeep
|
!api_chat/docs/.gitkeep
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
# API集合
|
||||||
|
|
||||||
|
该项目为所有API集合,集成了视觉分析、聊天对话和语音处理等功能。
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
API/
|
||||||
|
├── api/ # 视觉分析和处理模块
|
||||||
|
│ ├── producer/ # 主程序入口,生产者,分配任务
|
||||||
|
│ ├── cpm_analyze.py # CPM_OCR模型分析
|
||||||
|
│ ├── qwenvl_analyze.py # QwenVL_OCR模型分析
|
||||||
|
│ ├── cpm_scene.py # CPM_场景模型分析
|
||||||
|
│ ├── qwenvl_scene.py # QwenVL_场景模型分析
|
||||||
|
│ ├── compare.py # 人脸对比模型
|
||||||
|
│ ├── yolo.py # YOLO目标检测
|
||||||
|
│ ├── face.py # 人脸检测
|
||||||
|
│ ├── fall.py # 跌倒检测
|
||||||
|
│ ├── pose.py # 姿态估计
|
||||||
|
│ └── media.py # 媒体处理
|
||||||
|
├── api_chat/ # 聊天和语音处理模块
|
||||||
|
│ ├── producer_chat/ # 聊天生产者
|
||||||
|
│ ├── chat.py # 聊天功能
|
||||||
|
│ ├── tts.py # 文字转语音
|
||||||
|
│ ├── asr.py # 语音识别
|
||||||
|
│ ├── GPT_SoVITS/ # GPT_SoVITS模型集成,
|
||||||
|
│ ├── sample/ # OpenBMB模型——学习音色,音色+文本内容,
|
||||||
|
│ ├── tools/ # GPT_SoVITS模型——工具函数
|
||||||
|
│ ├── runtime/ # GPT_SoVITS模型——运行时函数
|
||||||
|
│ ├── docs/ # GPT_SoVITS模型——文档
|
||||||
|
│ ├── TEMP/ # OpenBMB模型临时文件夹,
|
||||||
|
│ └── before/ # 历史代码,可以忽略
|
||||||
|
├── api_history/ # api历史代码,可以忽略
|
||||||
|
├── chat_history/ # api_chat历史代码,可以忽略
|
||||||
|
└── api_old/ # api历史代码,可以忽略
|
||||||
|
```
|
||||||
|
|
||||||
|
## 主要功能
|
||||||
|
|
||||||
|
### 视觉分析模块 (api/)
|
||||||
|
- 目标检测和跟踪
|
||||||
|
- 人脸识别
|
||||||
|
- 人脸对比
|
||||||
|
- 姿态估计
|
||||||
|
- 跌倒检测
|
||||||
|
- 场景理解(基于CPM和QwenVL模型)
|
||||||
|
|
||||||
|
### 聊天对话模块 (api_chat/)
|
||||||
|
- 文本对话功能
|
||||||
|
- 语音识别 (ASR): 通过Whisper模型
|
||||||
|
- 文字转语音 (TTS): 通过GPT_SoVITS模型
|
||||||
|
- 多模型支持(通过Ollama)
|
||||||
|
|
||||||
|
## 使用说明
|
||||||
|
### API 部分 http://dev.obscura.work/v1
|
||||||
|
1. producer 目录 # 生产者,分配任务
|
||||||
|
2. 服务器:222.186.10.253:8005
|
||||||
|
3. kafka 配置:222.186.10.253:9092
|
||||||
|
topic分配:
|
||||||
|
- yolo: "yolo"
|
||||||
|
- pose: "pose"
|
||||||
|
- qwenvl: "qwenvl"
|
||||||
|
- qwenvl_analyze: "qwenvl_analyze"
|
||||||
|
- cpm: "cpm"
|
||||||
|
- cpm_analyze: "cpm_analyze"
|
||||||
|
- fall: "fall"
|
||||||
|
- face: "face"
|
||||||
|
- mediapipe: "mediapipe"
|
||||||
|
- compare: "compare"
|
||||||
|
|
||||||
|
4. redis 配置:150.158.144.159:13003
|
||||||
|
db分配:
|
||||||
|
- 4: "yolo"
|
||||||
|
- 5: "pose"
|
||||||
|
- 9: "qwenvl"
|
||||||
|
- 32: "qwenvl_analyze"
|
||||||
|
- 8: "cpm"
|
||||||
|
- 31: "cpm_analyze"
|
||||||
|
- 6: "fall"
|
||||||
|
- 7: "face"
|
||||||
|
- 10: "mediapipe"
|
||||||
|
- 30: "compare"
|
||||||
|
|
||||||
|
5. 模型配置:
|
||||||
|
- YOLO = "/obscura/models/yolov8x.pt"
|
||||||
|
- POSE = "/obscura/models/yolov8x-pose.pt"
|
||||||
|
- QWEN = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
|
||||||
|
- FALL = "/obscura/models/yolov8n-fall.pt"
|
||||||
|
- FACE = "/obscura/models/yolov8n-face.pt"
|
||||||
|
- MEDIAPIPE = "/obscura/models/face_landmarker.task"
|
||||||
|
- CPM(ollama) = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate"
|
||||||
|
6. 上传文件及结果保存目录:
|
||||||
|
- UPLOAD_DIR = "/obscura/task/upload"
|
||||||
|
- RESULT_DIR = "/obscura/task/result"
|
||||||
|
|
||||||
|
### API_Chat 部分 http://dev.obscura.work/v1_chat
|
||||||
|
1. producer_chat 目录 # 聊天生产者
|
||||||
|
2. 服务器:222.186.10.253:8008
|
||||||
|
3. kafka 配置:222.186.10.253:9092
|
||||||
|
topic分配:
|
||||||
|
- asr: "asr"
|
||||||
|
- chat: "chat"
|
||||||
|
- tts: "tts"
|
||||||
|
4. redis 配置:150.158.144.159:13003
|
||||||
|
db分配:
|
||||||
|
- 2: "api key"
|
||||||
|
- 3: "api使用情况"
|
||||||
|
- 11: "task 任务记录"
|
||||||
|
- 12: "asr 记录"
|
||||||
|
- 13: "chat 记录"
|
||||||
|
- 14: "tts 记录"
|
||||||
|
#语言
|
||||||
|
- 63: "session_zh 中文"
|
||||||
|
- 62: "session_en 英文"
|
||||||
|
- 61: "session_ko 韩语"
|
||||||
|
#音色
|
||||||
|
- 15: "girl"
|
||||||
|
- 16: "woman"
|
||||||
|
- 17: "man"
|
||||||
|
- 18: "leijun"
|
||||||
|
- 19: "dufu"
|
||||||
|
- 20: "hejiong"
|
||||||
|
- 21: "mahuateng"
|
||||||
|
- 22: "lidan"
|
||||||
|
- 23: "dabing"
|
||||||
|
- 24: "luoxiang"
|
||||||
|
- 25: "xuzhiyuan"
|
||||||
|
- 26: "yuhua"
|
||||||
|
- 27: "liuzhenyun"
|
||||||
|
5. 音频文件保存目录:
|
||||||
|
- OUTPUT_PATH=/obscura/task/audio_files # 音频文件保存目录
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
- 注意redis的db分配
|
||||||
|
- 注意kafka的topic分配
|
||||||
|
- 注意producer的config.py环境配置
|
||||||
|
- 注意producer_chat的.env环境配置
|
||||||
|
- 确保模型权重文件已正确配置
|
||||||
|
- 检查API密钥和环境变量设置
|
||||||
|
- 注意资源使用和性能优化
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
# 将本地Ollama API完全反向代理
|
|
||||||
from fastapi import FastAPI, Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
import httpx
|
|
||||||
OLLAMA_URL = "http://127.0.0.1:11434"
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
# 添加CORS中间件
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"], # 允许所有来源
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"], # 允许所有HTTP方法
|
|
||||||
allow_headers=["*"], # 允许所有HTTP头
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建异步HTTP客户端
|
|
||||||
async_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
|
||||||
async def proxy_to_ollama(request: Request, path: str):
|
|
||||||
if request.method == "OPTIONS":
|
|
||||||
# 处理预检请求
|
|
||||||
headers = {
|
|
||||||
"Access-Control-Allow-Origin": "*",
|
|
||||||
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
|
|
||||||
"Access-Control-Allow-Headers": "*",
|
|
||||||
}
|
|
||||||
return StreamingResponse(content=iter([]), headers=headers)
|
|
||||||
|
|
||||||
target_url = f"{OLLAMA_URL}/{path}"
|
|
||||||
|
|
||||||
# 获取请求体
|
|
||||||
body = await request.body()
|
|
||||||
|
|
||||||
# 获取请求头
|
|
||||||
headers = dict(request.headers)
|
|
||||||
headers.pop("host", None)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 将请求转换为Python请求
|
|
||||||
python_request = {
|
|
||||||
"method": request.method,
|
|
||||||
"url": target_url,
|
|
||||||
"headers": headers,
|
|
||||||
"data": body
|
|
||||||
}
|
|
||||||
|
|
||||||
# 使用Python请求发送到Ollama API
|
|
||||||
async with async_client.stream(**python_request) as response:
|
|
||||||
# 返回Ollama API的流式响应,并添加CORS头
|
|
||||||
response_headers = dict(response.headers)
|
|
||||||
response_headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
return StreamingResponse(
|
|
||||||
response.aiter_raw(),
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=response_headers
|
|
||||||
)
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
return {"error": f"请求Ollama API时发生错误: {str(e)}"}, 500
|
|
||||||
except httpx.StreamClosed:
|
|
||||||
# 处理流关闭异常
|
|
||||||
print("流已关闭,客户端可能已断开连接")
|
|
||||||
return {"error": "流已关闭,客户端可能已断开连接"}, 499
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def shutdown_event():
|
|
||||||
await async_client.aclose()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import uvicorn
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
|
|
||||||
# 测试Ollama API
|
|
||||||
test_url = "http://localhost:11434/api/generate"
|
|
||||||
test_data = {
|
|
||||||
"model": "llama3.1",
|
|
||||||
"prompt": "Why is the sky blue?",
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(test_url, json=test_data)
|
|
||||||
if response.status_code == 200:
|
|
||||||
print("Ollama API 测试成功:")
|
|
||||||
print(json.dumps(response.json(), indent=2))
|
|
||||||
else:
|
|
||||||
print(f"Ollama API 测试失败,状态码: {response.status_code}")
|
|
||||||
print(response.text)
|
|
||||||
except requests.RequestException as e:
|
|
||||||
print(f"Ollama API 测试出错: {str(e)}")
|
|
||||||
|
|
||||||
# 启动FastAPI应用
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
# config.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Kafka配置
|
||||||
|
KAFKA_BROKER = "222.186.10.253:9092"
|
||||||
|
KAFKA_GROUP_ID_PREFIX = "group"
|
||||||
|
|
||||||
|
# Redis配置
|
||||||
|
REDIS_HOST = "150.158.144.159"
|
||||||
|
REDIS_PORT = 13003
|
||||||
|
REDIS_PASSWORD = "Obscura@2024"
|
||||||
|
MAIN_REDIS_DB = 0
|
||||||
|
REDIS_API_DB = 2
|
||||||
|
REDIS_API_USAGE_DB = 3
|
||||||
|
# 目录配置
|
||||||
|
UPLOAD_DIR = "/obscura/task/upload"
|
||||||
|
RESULT_DIR = "/obscura/task/result"
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||||
|
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt"
|
||||||
|
POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt"
|
||||||
|
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct"
|
||||||
|
FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt"
|
||||||
|
FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt"
|
||||||
|
MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task"
|
||||||
|
# Ollama配置
|
||||||
|
OLLAMA_URL = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate"
|
||||||
|
|
||||||
|
# 各个worker的配置
|
||||||
|
WORKER_CONFIGS = {
|
||||||
|
"yolo": {
|
||||||
|
"kafka_topic": "yolo",
|
||||||
|
"redis_db": 4,
|
||||||
|
},
|
||||||
|
"pose": {
|
||||||
|
"kafka_topic": "pose",
|
||||||
|
"redis_db": 5,
|
||||||
|
},
|
||||||
|
"qwenvl": {
|
||||||
|
"kafka_topic": "qwenvl",
|
||||||
|
"redis_db": 9,
|
||||||
|
},
|
||||||
|
"qwenvl_analyze": {
|
||||||
|
"kafka_topic": "qwenvl_analyze",
|
||||||
|
"redis_db": 32,
|
||||||
|
},
|
||||||
|
"cpm": {
|
||||||
|
"kafka_topic": "cpm",
|
||||||
|
"redis_db": 8,
|
||||||
|
},
|
||||||
|
"cpm_analyze": {
|
||||||
|
"kafka_topic": "cpm_analyze",
|
||||||
|
"redis_db": 31,
|
||||||
|
},
|
||||||
|
"fall": {
|
||||||
|
"kafka_topic": "fall",
|
||||||
|
"redis_db": 6,
|
||||||
|
},
|
||||||
|
"face": {
|
||||||
|
"kafka_topic": "face",
|
||||||
|
"redis_db": 7,
|
||||||
|
},
|
||||||
|
"mediapipe": {
|
||||||
|
"kafka_topic": "mediapipe",
|
||||||
|
"redis_db": 10,
|
||||||
|
},
|
||||||
|
"compare": {
|
||||||
|
"kafka_topic": "compare",
|
||||||
|
"redis_db": 30,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# GPU设置
|
||||||
|
CUDA_DEVICE_0 = "cuda:0"
|
||||||
|
CUDA_DEVICE_1 = "cuda:1"
|
||||||
@@ -0,0 +1,442 @@
|
|||||||
|
# main.py
|
||||||
|
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from kafka import KafkaProducer
|
||||||
|
from redis import Redis
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import string
|
||||||
|
from decord import VideoReader
|
||||||
|
from PIL import Image
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
import logging
|
||||||
|
from config import *
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
v1_app = FastAPI()
|
||||||
|
app.mount("/v1", v1_app)
|
||||||
|
|
||||||
|
|
||||||
|
# CORS设置
|
||||||
|
# ALLOWED_ORIGINS = ['https://beta.obscura.work']
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
KAFKA_BROKER = KAFKA_BROKER
|
||||||
|
REDIS_HOST = REDIS_HOST
|
||||||
|
REDIS_PORT = REDIS_PORT
|
||||||
|
REDIS_PASSWORD = REDIS_PASSWORD
|
||||||
|
REDIS_DB = MAIN_REDIS_DB
|
||||||
|
REDIS_API_DB = REDIS_API_DB
|
||||||
|
REDIS_API_USAGE_DB = REDIS_API_USAGE_DB
|
||||||
|
UPLOAD_DIR = UPLOAD_DIR
|
||||||
|
RESULT_DIR = RESULT_DIR
|
||||||
|
MAX_FILE_AGE = timedelta(hours=1)
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||||
|
os.makedirs(RESULT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 定义支持的任务类型
|
||||||
|
KAFKA_TOPICS = {
|
||||||
|
'pose': 'pose',
|
||||||
|
'mediapipe': 'mediapipe',
|
||||||
|
'qwenvl': 'qwenvl',
|
||||||
|
'yolo': 'yolo',
|
||||||
|
'fall': 'fall',
|
||||||
|
'face': 'face',
|
||||||
|
'cpm': 'cpm',
|
||||||
|
'compare': 'compare',
|
||||||
|
'qwenvl_analyze': 'qwenvl_analyze',
|
||||||
|
'cpm_analyze': 'cpm_analyze'
|
||||||
|
}
|
||||||
|
|
||||||
|
TASK_TYPES = list(KAFKA_TOPICS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化 Kafka Producer
|
||||||
|
producer = KafkaProducer(
|
||||||
|
bootstrap_servers=[KAFKA_BROKER],
|
||||||
|
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化 Redis
|
||||||
|
redis_client = Redis(
|
||||||
|
host=REDIS_HOST,
|
||||||
|
port=REDIS_PORT,
|
||||||
|
password=REDIS_PASSWORD,
|
||||||
|
db=REDIS_DB
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_api_client = Redis(
|
||||||
|
host=REDIS_HOST,
|
||||||
|
port=REDIS_PORT,
|
||||||
|
password=REDIS_PASSWORD,
|
||||||
|
db=REDIS_API_DB
|
||||||
|
)
|
||||||
|
|
||||||
|
redis_api_usage_client = Redis(
|
||||||
|
host=REDIS_HOST,
|
||||||
|
port=REDIS_PORT,
|
||||||
|
password=REDIS_PASSWORD,
|
||||||
|
db=REDIS_API_USAGE_DB
|
||||||
|
)
|
||||||
|
redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db'])
|
||||||
|
redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db'])
|
||||||
|
redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db'])
|
||||||
|
redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db'])
|
||||||
|
redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db'])
|
||||||
|
redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db'])
|
||||||
|
redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db'])
|
||||||
|
redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db'])
|
||||||
|
redis_qwenvl_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl_analyze']['redis_db'])
|
||||||
|
redis_cpm_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm_analyze']['redis_db'])
|
||||||
|
|
||||||
|
@v1_app.get('/favicon.ico', include_in_schema=False)
|
||||||
|
async def favicon():
|
||||||
|
file_name = "favicon.ico"
|
||||||
|
file_path = os.path.join(app.root_path, "static", file_name)
|
||||||
|
if os.path.isfile(file_path):
|
||||||
|
return FileResponse(file_path)
|
||||||
|
else:
|
||||||
|
return {"message": "Favicon not found"}, 404
|
||||||
|
|
||||||
|
# 定义API密钥头部
|
||||||
|
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||||
|
|
||||||
|
# 定义 base62 字符集
|
||||||
|
BASE62 = string.digits + string.ascii_letters
|
||||||
|
|
||||||
|
# 验证API密钥的函数
|
||||||
|
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||||
|
if api_key and api_key.startswith("Bearer "):
|
||||||
|
key = api_key.split(" ")[1]
|
||||||
|
if key.startswith("obs-"):
|
||||||
|
return key
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="无效的API密钥",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||||
|
logging.info(f"验证API密钥: {api_key}")
|
||||||
|
redis_key = f"api_key:{api_key}"
|
||||||
|
|
||||||
|
api_key_info = redis_api_client.hgetall(redis_key)
|
||||||
|
|
||||||
|
if not api_key_info:
|
||||||
|
logging.warning(f"API密钥不存在: {api_key}")
|
||||||
|
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||||
|
|
||||||
|
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||||
|
|
||||||
|
if api_key_info.get('is_active') != '1':
|
||||||
|
logging.warning(f"API密钥已停用: {api_key}")
|
||||||
|
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||||
|
|
||||||
|
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||||
|
if datetime.now(timezone.utc) > expires_at:
|
||||||
|
logging.warning(f"API密钥已过期: {api_key}")
|
||||||
|
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||||
|
|
||||||
|
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||||
|
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||||
|
|
||||||
|
logging.info(f"API密钥验证成功: {api_key}")
|
||||||
|
return {
|
||||||
|
"api_key": api_key,
|
||||||
|
**api_key_info,
|
||||||
|
**usage_info
|
||||||
|
}
|
||||||
|
|
||||||
|
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||||
|
redis_key = f"api_key:{api_key}"
|
||||||
|
current_time = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
pipe = redis_api_usage_client.pipeline()
|
||||||
|
|
||||||
|
# 更新总的token使用量
|
||||||
|
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||||
|
pipe.hset(redis_key, "last_used_at", current_time)
|
||||||
|
|
||||||
|
# 更新特定模型的token使用量
|
||||||
|
model_tokens_field = f"{model_name}_tokens_used"
|
||||||
|
model_last_used_field = f"{model_name}_last_used_at"
|
||||||
|
|
||||||
|
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||||
|
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||||
|
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
def calculate_tokens(file_path: str, file_type: str) -> int:
|
||||||
|
base_tokens = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_size = os.path.getsize(file_path) # 获取文件大小(字节)
|
||||||
|
|
||||||
|
# 基础token:每MB文件大小消耗10个token
|
||||||
|
base_tokens = int((file_size / (1024 * 1024)) * 0.1)
|
||||||
|
|
||||||
|
if file_type == "image":
|
||||||
|
img = Image.open(file_path)
|
||||||
|
width, height = img.size
|
||||||
|
pixel_count = width * height
|
||||||
|
|
||||||
|
# 图片token:每100个像素额外消耗5个token
|
||||||
|
image_tokens = int((pixel_count / 100000000) * 0.1)
|
||||||
|
|
||||||
|
base_tokens += image_tokens
|
||||||
|
|
||||||
|
elif file_type == "video":
|
||||||
|
vr = VideoReader(file_path)
|
||||||
|
fps = vr.get_avg_fps()
|
||||||
|
frame_count = len(vr)
|
||||||
|
width, height = vr[0].shape[1], vr[0].shape[0]
|
||||||
|
|
||||||
|
pixel_count = width * height * frame_count
|
||||||
|
duration = frame_count / fps # 视频时长(秒)
|
||||||
|
|
||||||
|
# 视频token:每100万像素每秒额外消耗1个token
|
||||||
|
video_tokens = int((pixel_count / 100000000) * (duration / 60))
|
||||||
|
|
||||||
|
base_tokens += video_tokens
|
||||||
|
|
||||||
|
return max(1, base_tokens) # 确保至少返回1个token
|
||||||
|
except Exception as e:
|
||||||
|
print(f"计算token时出错: {str(e)}")
|
||||||
|
return 1 # 出错时返回默认值1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_file(file: UploadFile, task_type: str, api_key_info: dict):
|
||||||
|
if task_type not in KAFKA_TOPICS:
|
||||||
|
raise HTTPException(status_code=400, detail="不支持的任务类型")
|
||||||
|
|
||||||
|
content = await file.read()
|
||||||
|
file_extension = os.path.splitext(file.filename)[1].lower()
|
||||||
|
new_filename = f"{uuid.uuid4()}{file_extension}"
|
||||||
|
|
||||||
|
file_path = os.path.join(UPLOAD_DIR, new_filename)
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 计算 token
|
||||||
|
file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video"
|
||||||
|
tokens_required = calculate_tokens(file_path, file_type)
|
||||||
|
|
||||||
|
# 检查并更新 token 使用量
|
||||||
|
api_key = api_key_info['api_key']
|
||||||
|
usage_key = f"api_key:{api_key}"
|
||||||
|
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||||
|
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||||
|
|
||||||
|
if tokens_used + tokens_required > total_tokens:
|
||||||
|
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||||
|
|
||||||
|
# 更新 token 使用量
|
||||||
|
await update_token_usage(api_key, tokens_required, task_type)
|
||||||
|
|
||||||
|
# 创建任务记录
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
task_data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"filename": new_filename,
|
||||||
|
"file_type": file_type,
|
||||||
|
"task_type": task_type,
|
||||||
|
"status": "queued",
|
||||||
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 存储任务信息到Redis
|
||||||
|
redis_client.set(f"task:{task_id}", json.dumps(task_data))
|
||||||
|
logging.info(f"任务信息已存储到Redis: {task_id}")
|
||||||
|
|
||||||
|
# 发送任务到对应的Kafka主题
|
||||||
|
kafka_topic = KAFKA_TOPICS[task_type]
|
||||||
|
producer.send(kafka_topic, task_data)
|
||||||
|
logging.info(f"任务已发送到Kafka主题: {kafka_topic}")
|
||||||
|
|
||||||
|
# 获取更新后的 token 使用情况
|
||||||
|
updated_api_key_info = await verify_api_key(api_key)
|
||||||
|
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||||
|
model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0))
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"message": "文件已上传并排队等待处理",
|
||||||
|
"task_id": task_id,
|
||||||
|
"tokens_used": tokens_required,
|
||||||
|
"total_tokens_used": new_tokens_used,
|
||||||
|
f"{task_type}_tokens_used": model_tokens_used,
|
||||||
|
"tokens_remaining": total_tokens - new_tokens_used
|
||||||
|
}
|
||||||
|
logging.info(f"上传文件完成: {task_id}")
|
||||||
|
return JSONResponse(content=response_data)
|
||||||
|
|
||||||
|
# 为每个任务类型创建单独的端点
|
||||||
|
@v1_app.post("/pose")
|
||||||
|
async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
logging.info(f"收到 /pose端点的请求")
|
||||||
|
return await upload_file(file, task_type="pose", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/cpm")
|
||||||
|
async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="cpm", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/qwenvl")
|
||||||
|
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/qwenvl_analyze")
|
||||||
|
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="qwenvl_analyze", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/cpm_analyze")
|
||||||
|
async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="cpm_analyze", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_app.post("/yolo")
|
||||||
|
async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="yolo", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/fall")
|
||||||
|
async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="fall", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/face")
|
||||||
|
async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="face", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/mediapipe")
|
||||||
|
async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
@v1_app.post("/compare")
|
||||||
|
async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
return await upload_file(file, task_type="compare", api_key_info=api_key_info)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_app.get("/result/{task_id}")
|
||||||
|
async def get_result(task_id: str, api_key: str = Depends(get_api_key)):
|
||||||
|
api_key_info = await verify_api_key(api_key)
|
||||||
|
if not api_key_info:
|
||||||
|
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||||
|
|
||||||
|
# 从 REDIS_DB (15) 获取任务状态
|
||||||
|
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||||
|
if not task_info:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||||
|
|
||||||
|
if task_info['status'] != 'completed':
|
||||||
|
return {"status": task_info['status'], "message": "Task is not completed yet"}
|
||||||
|
|
||||||
|
result_type = task_info['result_type']
|
||||||
|
result_key = task_info['result_key']
|
||||||
|
|
||||||
|
# 根据任务类型选择相应的 Redis 客户端
|
||||||
|
redis_client_map = {
|
||||||
|
'pose': redis_pose_client,
|
||||||
|
'cpm': redis_cpm_client,
|
||||||
|
'yolo': redis_yolo_client,
|
||||||
|
'face': redis_face_client,
|
||||||
|
'fall': redis_fall_client,
|
||||||
|
'mediapipe': redis_mediapipe_client,
|
||||||
|
'qwenvl': redis_qwenvl_client,
|
||||||
|
'compare': redis_compare_client,
|
||||||
|
'qwenvl_analyze': redis_qwenvl_analyze_client,
|
||||||
|
'cpm_analyze': redis_cpm_analyze_client
|
||||||
|
}
|
||||||
|
|
||||||
|
result_redis = redis_client_map.get(result_type)
|
||||||
|
if not result_redis:
|
||||||
|
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||||
|
|
||||||
|
result = result_redis.hgetall(result_key)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||||
|
|
||||||
|
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||||
|
|
||||||
|
# 将 result 字段解析为 JSON(如果存在)
|
||||||
|
if 'result' in result:
|
||||||
|
result['result'] = json.loads(result['result'])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"result_type": result_type,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
@v1_app.get("/annotated/{task_id}")
|
||||||
|
async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)):
|
||||||
|
api_key_info = await verify_api_key(api_key)
|
||||||
|
if not api_key_info:
|
||||||
|
raise HTTPException(status_code=403, detail="无效的API密钥")
|
||||||
|
|
||||||
|
# 从 REDIS_DB (15) 获取任务信息
|
||||||
|
task_info = redis_client.hgetall(f"task:{task_id}")
|
||||||
|
if not task_info:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()}
|
||||||
|
|
||||||
|
if task_info['status'] != 'completed':
|
||||||
|
raise HTTPException(status_code=400, detail="Task is not completed yet")
|
||||||
|
|
||||||
|
result_type = task_info.get('result_type')
|
||||||
|
result_key = task_info.get('result_key')
|
||||||
|
|
||||||
|
if not result_key:
|
||||||
|
raise HTTPException(status_code=404, detail="Result key not found")
|
||||||
|
|
||||||
|
if result_type in ['cpm', 'qwenvl','cpm_analyze', 'qwenvl_analyze']:
|
||||||
|
raise HTTPException(status_code=400, detail="Annotated image not available for this task type")
|
||||||
|
|
||||||
|
# 根据任务类型选择相应的 Redis 客户端
|
||||||
|
redis_client_map = {
|
||||||
|
'pose': redis_pose_client,
|
||||||
|
'yolo': redis_yolo_client,
|
||||||
|
'face': redis_face_client,
|
||||||
|
'fall': redis_fall_client,
|
||||||
|
'mediapipe': redis_mediapipe_client,
|
||||||
|
'compare': redis_compare_client
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
result_redis = redis_client_map.get(result_type)
|
||||||
|
if not result_redis:
|
||||||
|
raise HTTPException(status_code=400, detail="Unsupported result type")
|
||||||
|
|
||||||
|
result = result_redis.hgetall(result_key)
|
||||||
|
if not result:
|
||||||
|
raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found")
|
||||||
|
|
||||||
|
result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()}
|
||||||
|
|
||||||
|
result_file = result.get('result_file')
|
||||||
|
if not result_file:
|
||||||
|
raise HTTPException(status_code=404, detail="Result file not found")
|
||||||
|
|
||||||
|
file_path = os.path.join(RESULT_DIR, result_file)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise HTTPException(status_code=404, detail="Result image file not found")
|
||||||
|
|
||||||
|
return FileResponse(file_path, media_type="image/png")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8005)
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
python-multipart
|
||||||
|
kafka-python
|
||||||
|
redis
|
||||||
|
pillow
|
||||||
|
decord
|
||||||
|
pydantic
|
||||||
|
requests
|
||||||
|
python-jose[cryptography]
|
||||||
|
passlib[bcrypt]
|
||||||
|
sqlalchemy
|
||||||
@@ -0,0 +1,94 @@
|
|||||||
|
# Kafka 配置
|
||||||
|
KAFKA_BROKER=222.186.10.253:9092
|
||||||
|
KAFKA_ASR_TOPIC=asr
|
||||||
|
KAFKA_CHAT_TOPIC=chat
|
||||||
|
KAFKA_TTS_TOPIC=tts
|
||||||
|
|
||||||
|
|
||||||
|
# Redis 配置
|
||||||
|
REDIS_HOST=150.158.144.159
|
||||||
|
REDIS_PORT=13003
|
||||||
|
REDIS_ASR_DB=12
|
||||||
|
REDIS_CHAT_DB=13
|
||||||
|
REDIS_TTS_DB=14
|
||||||
|
REDIS_PASSWORD=Obscura@2024
|
||||||
|
REDIS_API_DB=2
|
||||||
|
REDIS_API_USAGE_DB=3
|
||||||
|
REDIS_TASK_DB=11
|
||||||
|
REDIS_SESSION_DB=63
|
||||||
|
|
||||||
|
REDIS_SESSION_DB_ZH=63
|
||||||
|
REDIS_SESSION_DB_EN=62
|
||||||
|
REDIS_SESSION_DB_KO=61
|
||||||
|
|
||||||
|
# CORS 配置
|
||||||
|
# ALLOWED_ORIGINS=https://beta.obscura.work
|
||||||
|
|
||||||
|
|
||||||
|
# GPT-SoVITS 配置
|
||||||
|
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||||
|
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||||
|
REF_AUDIO_PATH=sample/woman.wav
|
||||||
|
REF_TEXT_PATH=sample/woman.txt
|
||||||
|
REF_LANGUAGE=中文
|
||||||
|
TARGET_LANGUAGE=多语种混合
|
||||||
|
OUTPUT_PATH=/obscura/task/audio_files
|
||||||
|
|
||||||
|
# VOICE_CONFIGS
|
||||||
|
GIRL_REF_AUDIO=sample/gril.wav
|
||||||
|
GIRL_REF_TEXT=sample/gril.txt
|
||||||
|
|
||||||
|
WOMAN_REF_AUDIO=sample/woman.wav
|
||||||
|
WOMAN_REF_TEXT=sample/woman.txt
|
||||||
|
|
||||||
|
|
||||||
|
MAN_REF_AUDIO=sample/man.wav
|
||||||
|
MAN_REF_TEXT=sample/man.txt
|
||||||
|
|
||||||
|
LEIJUN_REF_AUDIO=sample/leijun.wav
|
||||||
|
LEIJUN_REF_TEXT=sample/leijun.txt
|
||||||
|
|
||||||
|
DUFU_REF_AUDIO=sample/dufu.wav
|
||||||
|
DUFU_REF_TEXT=sample/dufu.txt
|
||||||
|
|
||||||
|
HEJIONG_REF_AUDIO=sample/hejiong.wav
|
||||||
|
HEJIONG_REF_TEXT=sample/hejiong.txt
|
||||||
|
|
||||||
|
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
|
||||||
|
MAHUATENG_REF_TEXT=sample/mahuateng.txt
|
||||||
|
|
||||||
|
LIDAN_REF_AUDIO=sample/lidan.wav
|
||||||
|
LIDAN_REF_TEXT=sample/lidan.txt
|
||||||
|
|
||||||
|
YUHUA_REF_AUDIO=sample/yuhua.wav
|
||||||
|
YUHUA_REF_TEXT=sample/yuhua.txt
|
||||||
|
|
||||||
|
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
|
||||||
|
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
|
||||||
|
|
||||||
|
DABING_REF_AUDIO=sample/dabing.wav
|
||||||
|
DABING_REF_TEXT=sample/dabing.txt
|
||||||
|
|
||||||
|
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
|
||||||
|
LUOXIANG_REF_TEXT=sample/luoxiang.txt
|
||||||
|
|
||||||
|
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
|
||||||
|
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
|
||||||
|
|
||||||
|
|
||||||
|
REDIS_GIRL_DB = 15
|
||||||
|
REDIS_WOMAN_DB = 16
|
||||||
|
REDIS_MAN_DB = 17
|
||||||
|
REDIS_LEIJUN_DB = 18
|
||||||
|
REDIS_DUFU_DB = 19
|
||||||
|
REDIS_HEJIONG_DB = 20
|
||||||
|
REDIS_MAHUATENG_DB = 21
|
||||||
|
REDIS_LIDAN_DB = 22
|
||||||
|
REDIS_DABING_DB = 23
|
||||||
|
REDIS_LUOXIANG_DB = 24
|
||||||
|
REDIS_XUZHIYUAN_DB = 25
|
||||||
|
REDIS_YUHUA_DB = 26
|
||||||
|
REDIS_LIUZHENYUN_DB = 27
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,406 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from kafka import KafkaProducer
|
||||||
|
from redis import Redis
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import tempfile
|
||||||
|
import hashlib
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
# 在文件顶部添加这个函数
|
||||||
|
def get_audio_hash(text):
|
||||||
|
return hashlib.md5(text.encode()).hexdigest()
|
||||||
|
|
||||||
|
# 加载 .env 文件
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
v1_chat_app = FastAPI()
|
||||||
|
app.mount("/v1_chat", v1_chat_app)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||||
|
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||||
|
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||||
|
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||||
|
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||||
|
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||||
|
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||||
|
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
|
||||||
|
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
|
||||||
|
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||||
|
|
||||||
|
|
||||||
|
# Redis 配置
|
||||||
|
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
|
||||||
|
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
|
||||||
|
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
|
||||||
|
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
|
||||||
|
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
|
||||||
|
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
|
||||||
|
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
|
||||||
|
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
|
||||||
|
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
|
||||||
|
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
|
||||||
|
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
|
||||||
|
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
|
||||||
|
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
|
||||||
|
|
||||||
|
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||||
|
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||||
|
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||||
|
|
||||||
|
OUTPUT_PATH= os.getenv('OUTPUT_PATH')
|
||||||
|
|
||||||
|
# 初始化 Kafka Producer
|
||||||
|
producer = KafkaProducer(
|
||||||
|
bootstrap_servers=[KAFKA_BROKER],
|
||||||
|
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化 Redis
|
||||||
|
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
|
||||||
|
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
|
||||||
|
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
|
||||||
|
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
|
||||||
|
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
|
||||||
|
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
|
||||||
|
|
||||||
|
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
|
||||||
|
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
|
||||||
|
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
|
||||||
|
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
|
||||||
|
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
|
||||||
|
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
|
||||||
|
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
|
||||||
|
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
|
||||||
|
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
|
||||||
|
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
|
||||||
|
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
|
||||||
|
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
|
||||||
|
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
|
||||||
|
|
||||||
|
# 创建一个音色到对应 Redis 客户端的映射
|
||||||
|
voice_to_redis = {
|
||||||
|
"default": redis_tts_girl,
|
||||||
|
"girl": redis_tts_girl,
|
||||||
|
"woman": redis_tts_woman,
|
||||||
|
"man": redis_tts_man,
|
||||||
|
"leijun": redis_tts_leijun,
|
||||||
|
"dufu": redis_tts_dufu,
|
||||||
|
"hejiong": redis_tts_hejiong,
|
||||||
|
"mahuateng": redis_tts_mahuateng,
|
||||||
|
"lidan": redis_tts_lidan,
|
||||||
|
"yuhua": redis_tts_yuhua,
|
||||||
|
"liuzhenyun": redis_tts_liuzhenyun,
|
||||||
|
"dabing": redis_tts_dabing,
|
||||||
|
"luoxiang": redis_tts_luoxiang,
|
||||||
|
"xuzhiyuan": redis_tts_xuzhiyuan
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 定义API密钥头部
|
||||||
|
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||||
|
|
||||||
|
def get_audio_hash(text):
|
||||||
|
return hashlib.md5(text.encode()).hexdigest()
|
||||||
|
|
||||||
|
# 验证API密钥的函数
|
||||||
|
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||||
|
if api_key and api_key.startswith("Bearer "):
|
||||||
|
key = api_key.split(" ")[1]
|
||||||
|
if key.startswith("obs-"):
|
||||||
|
return key
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="无效的API密钥",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||||
|
redis_key = f"api_key:{api_key}"
|
||||||
|
|
||||||
|
api_key_info = redis_api_client.hgetall(redis_key)
|
||||||
|
|
||||||
|
if not api_key_info:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||||
|
|
||||||
|
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||||
|
|
||||||
|
if api_key_info.get('is_active') != '1':
|
||||||
|
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||||
|
|
||||||
|
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||||
|
if datetime.now(timezone.utc) > expires_at:
|
||||||
|
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||||
|
|
||||||
|
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||||
|
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"api_key": api_key,
|
||||||
|
**api_key_info,
|
||||||
|
**usage_info
|
||||||
|
}
|
||||||
|
|
||||||
|
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
|
||||||
|
redis_key = f"api_key:{api_key}"
|
||||||
|
current_time = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
pipe = redis_api_usage_client.pipeline()
|
||||||
|
|
||||||
|
pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
|
||||||
|
pipe.hset(redis_key, "last_used_at", current_time)
|
||||||
|
|
||||||
|
model_tokens_field = f"{model_name}_tokens_used"
|
||||||
|
model_last_used_field = f"{model_name}_last_used_at"
|
||||||
|
|
||||||
|
pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
|
||||||
|
pipe.hset(redis_key, model_last_used_field, current_time)
|
||||||
|
|
||||||
|
pipe.execute()
|
||||||
|
|
||||||
|
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
|
||||||
|
api_key = api_key_info['api_key']
|
||||||
|
usage_key = f"api_key:{api_key}"
|
||||||
|
total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
|
||||||
|
tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
|
||||||
|
|
||||||
|
if tokens_used + tokens_required > total_tokens:
|
||||||
|
raise HTTPException(status_code=403, detail="Token 余额不足")
|
||||||
|
|
||||||
|
# 更新 token 使用量
|
||||||
|
await update_token_usage(api_key, tokens_required, model_name)
|
||||||
|
|
||||||
|
# 发送任务到Kafka
|
||||||
|
producer.send(kafka_topic, task_data)
|
||||||
|
|
||||||
|
# 获取更新后的 token 使用情况
|
||||||
|
updated_api_key_info = await verify_api_key(api_key)
|
||||||
|
new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
|
||||||
|
model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"{model_name.upper()}请求已排队等待处理",
|
||||||
|
"tokens_used": tokens_required,
|
||||||
|
"total_tokens_used": new_tokens_used,
|
||||||
|
f"{model_name}_tokens_used": model_tokens_used,
|
||||||
|
"tokens_remaining": total_tokens - new_tokens_used
|
||||||
|
}
|
||||||
|
|
||||||
|
class TTSRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
voice: str = Field(..., description="选择的音色")
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
query: str
|
||||||
|
model: str = "qwen2.5:3b"
|
||||||
|
|
||||||
|
|
||||||
|
@v1_chat_app.post("/tts")
|
||||||
|
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
text_hash = get_audio_hash(request.text)
|
||||||
|
|
||||||
|
# 验证音色选择
|
||||||
|
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||||
|
if request.voice not in valid_voices:
|
||||||
|
raise HTTPException(status_code=400, detail="无效的音色选择")
|
||||||
|
|
||||||
|
# 如果声音是 'default',则将其视为 'girl'
|
||||||
|
voice = 'girl' if request.voice == 'default' else request.voice
|
||||||
|
|
||||||
|
# 使用对应音色的 Redis 客户端
|
||||||
|
redis_tts = voice_to_redis[request.voice]
|
||||||
|
|
||||||
|
# 检查是否已存在相同内容的音频文件
|
||||||
|
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||||
|
if existing_audio_info:
|
||||||
|
existing_audio_path = json.loads(existing_audio_info)['path']
|
||||||
|
if os.path.exists(existing_audio_path):
|
||||||
|
return {
|
||||||
|
"message": "TTS请求已完成",
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": "completed",
|
||||||
|
"audio_path": existing_audio_path
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果不存在,创建新的任务
|
||||||
|
task_data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"text": request.text,
|
||||||
|
"text_hash": text_hash,
|
||||||
|
"voice": request.voice,
|
||||||
|
"status": "queued",
|
||||||
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 存储任务信息到Redis
|
||||||
|
redis_task_client.set(f"task_status:tts:{task_id}", "queued")
|
||||||
|
|
||||||
|
result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
|
||||||
|
result["task_id"] = task_id
|
||||||
|
return result
|
||||||
|
@v1_chat_app.post("/asr")
|
||||||
|
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
UPLOAD_DIR = "/obscura/task/audio_upload"
|
||||||
|
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||||
|
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
|
||||||
|
|
||||||
|
with open(file_path, "wb") as temp_audio:
|
||||||
|
content = await audio.read()
|
||||||
|
temp_audio.write(content)
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
'file_path': file_path,
|
||||||
|
'task_id': task_id,
|
||||||
|
'status': 'queued'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 存储任务状态,使用一致的键名格式
|
||||||
|
redis_task_client.set(f"task_status:asr:{task_id}", "queued")
|
||||||
|
|
||||||
|
result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
|
||||||
|
result["task_id"] = task_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
@v1_chat_app.post("/chat")
|
||||||
|
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
task_data = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"session_id": request.session_id,
|
||||||
|
"query": request.query,
|
||||||
|
"model": request.model,
|
||||||
|
"status": "queued",
|
||||||
|
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 设置任务状态为 "queued"
|
||||||
|
redis_task_client.set(f"chat:{task_id}:status", "queued")
|
||||||
|
|
||||||
|
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
|
||||||
|
result["task_id"] = task_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@v1_chat_app.get("/chat_result/{task_id}")
|
||||||
|
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
# 从Redis任务数据库获取任务状态
|
||||||
|
task_status = redis_task_client.get(f"chat:{task_id}:status")
|
||||||
|
if task_status:
|
||||||
|
status = task_status.decode('utf-8')
|
||||||
|
if status == "completed":
|
||||||
|
# 从Redis任务数据库获取聊天结果
|
||||||
|
chat_result = redis_task_client.get(f"chat:{task_id}:result")
|
||||||
|
if chat_result:
|
||||||
|
result = json.loads(chat_result)
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
return {"status": status}
|
||||||
|
|
||||||
|
return {"status": "not_found"}
|
||||||
|
|
||||||
|
@v1_chat_app.get("/tts_result/{task_id}")
|
||||||
|
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||||
|
if task_status:
|
||||||
|
status = task_status.decode('utf-8')
|
||||||
|
if status == "completed":
|
||||||
|
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||||
|
if task_info:
|
||||||
|
task_data = json.loads(task_info)
|
||||||
|
text_hash = task_data['text_hash']
|
||||||
|
voice = task_data['voice']
|
||||||
|
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||||
|
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||||
|
|
||||||
|
audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||||
|
if audio_info:
|
||||||
|
audio_path = json.loads(audio_info)['path']
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"audio_path": audio_path
|
||||||
|
}
|
||||||
|
return {"status": status}
|
||||||
|
|
||||||
|
return {"status": "not_found"}
|
||||||
|
|
||||||
|
@v1_chat_app.get("/asr_result/{task_id}")
|
||||||
|
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
# 从Redis任务数据库获取任务状态,使用一致的键名格式
|
||||||
|
task_status = redis_task_client.get(f"task_status:asr:{task_id}")
|
||||||
|
if task_status:
|
||||||
|
status = task_status.decode('utf-8')
|
||||||
|
if status == "completed":
|
||||||
|
# 从Redis ASR结果数据库获取转录结果
|
||||||
|
transcription = redis_asr_client.get(f"asr:{task_id}")
|
||||||
|
if transcription:
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"transcription": transcription.decode('utf-8')
|
||||||
|
}
|
||||||
|
return {"status": status}
|
||||||
|
|
||||||
|
return {"status": "not_found"}
|
||||||
|
|
||||||
|
@v1_chat_app.get("/tts_audio/{task_id}")
|
||||||
|
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||||
|
if task_status:
|
||||||
|
status = task_status.decode('utf-8')
|
||||||
|
if status == "completed":
|
||||||
|
# 从任务信息中获取使用的音色
|
||||||
|
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||||
|
if task_info:
|
||||||
|
task_data = json.loads(task_info)
|
||||||
|
voice = task_data.get('voice', 'girl') # 默认使用 'girl'
|
||||||
|
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||||
|
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||||
|
|
||||||
|
# 从对应音色的 Redis 获取音频文件路径
|
||||||
|
audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
|
||||||
|
if audio_info:
|
||||||
|
audio_path = json.loads(audio_info)['path']
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
file_name = os.path.basename(audio_path)
|
||||||
|
return FileResponse(audio_path, media_type="audio/wav", filename=file_name)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail="音频文件不存在")
|
||||||
|
elif status == "queued" or status == "processing":
|
||||||
|
raise HTTPException(status_code=202, detail="音频文件正在生成中")
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="任务处理失败")
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail="任务不存在")
|
||||||
|
@v1_chat_app.get("/getvoice")
|
||||||
|
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
|
||||||
|
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||||
|
return {"available_voices": valid_voices}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
pydantic
|
||||||
|
kafka-python
|
||||||
|
redis
|
||||||
|
python-dotenv
|
||||||
|
python-multipart
|
||||||
Reference in New Issue
Block a user