diff --git a/.gitignore b/.gitignore index 92b7a02..48c1709 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ api_chat/runtime/* api_chat/docs/* !api_chat/docs/.gitkeep + diff --git a/README.md b/README.md new file mode 100644 index 0000000..d4b14e1 --- /dev/null +++ b/README.md @@ -0,0 +1,140 @@ +# API集合 + +该项目为所有API集合,集成了视觉分析、聊天对话和语音处理等功能。 + +## 项目结构 + + ``` + API/ + ├── api/ # 视觉分析和处理模块 + │ ├── producer/ # 主程序入口,生产者,分配任务 + │ ├── cpm_analyze.py # CPM_OCR模型分析 + │ ├── qwenvl_analyze.py # QwenVL_OCR模型分析 + │ ├── cpm_scene.py # CPM_场景模型分析 + │ ├── qwenvl_scene.py # QwenVL_场景模型分析 + │ ├── compare.py # 人脸对比模型 + │ ├── yolo.py # YOLO目标检测 + │ ├── face.py # 人脸检测 + │ ├── fall.py # 跌倒检测 + │ ├── pose.py # 姿态估计 + │ └── media.py # 媒体处理 + ├── api_chat/ # 聊天和语音处理模块 + │ ├── producer_chat/ # 聊天生产者 + │ ├── chat.py # 聊天功能 + │ ├── tts.py # 文字转语音 + │ ├── asr.py # 语音识别 + │ ├── GPT_SoVITS/ # GPT_SoVITS模型集成, + │ ├── sample/ # OpenBMB模型——学习音色,音色+文本内容, + │ ├── tools/ # GPT_SoVITS模型——工具函数 + │ ├── runtime/ # GPT_SoVITS模型——运行时函数 + │ ├── docs/ # GPT_SoVITS模型——文档 + │ ├── TEMP/ # OpenBMB模型临时文件夹, + │ └── before/ # 历史代码,可以忽略 + ├── api_history/ # api历史代码,可以忽略 + ├── chat_history/ # api_chat历史代码,可以忽略 + └── api_old/ # api历史代码,可以忽略 + ``` + +## 主要功能 + +### 视觉分析模块 (api/) + - 目标检测和跟踪 + - 人脸识别 + - 人脸对比 + - 姿态估计 + - 跌倒检测 + - 场景理解(基于CPM和QwenVL模型) + +### 聊天对话模块 (api_chat/) + - 文本对话功能 + - 语音识别 (ASR): 通过Whisper模型 + - 文字转语音 (TTS): 通过GPT_SoVITS模型 + - 多模型支持(通过Ollama) + +## 使用说明 +### API 部分 http://dev.obscura.work/v1 + 1. producer 目录 # 生产者,分配任务 + 2. 服务器:222.186.10.253:8005 + 3. kafka 配置:222.186.10.253:9092 + topic分配: + - yolo: "yolo" + - pose: "pose" + - qwenvl: "qwenvl" + - qwenvl_analyze: "qwenvl_analyze" + - cpm: "cpm" + - cpm_analyze: "cpm_analyze" + - fall: "fall" + - face: "face" + - mediapipe: "mediapipe" + - compare: "compare" + + 4. redis 配置:150.158.144.159:13003 + db分配: + - 4: "yolo" + - 5: "pose" + - 9: "qwenvl" + - 32: "qwenvl_analyze" + - 8: "cpm" + - 31: "cpm_analyze" + - 6: "fall" + - 7: "face" + - 10: "mediapipe" + - 30: "compare" + + 5. 模型配置: + - YOLO = "/obscura/models/yolov8x.pt" + - POSE = "/obscura/models/yolov8x-pose.pt" + - QWEN = "/obscura/models/qwen/Qwen2-VL-2B-Instruct" + - FALL = "/obscura/models/yolov8n-fall.pt" + - FACE = "/obscura/models/yolov8n-face.pt" + - MEDIAPIPE = "/obscura/models/face_landmarker.task" + - CPM(ollama) = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate" + 6. 上传文件及结果保存目录: + - UPLOAD_DIR = "/obscura/task/upload" + - RESULT_DIR = "/obscura/task/result" + +### API_Chat 部分 http://dev.obscura.work/v1_chat + 1. producer_chat 目录 # 聊天生产者 + 2. 服务器:222.186.10.253:8008 + 3. kafka 配置:222.186.10.253:9092 + topic分配: + - asr: "asr" + - chat: "chat" + - tts: "tts" + 4. redis 配置:150.158.144.159:13003 + db分配: + - 2: "api key" + - 3: "api使用情况" + - 11: "task 任务记录" + - 12: "asr 记录" + - 13: "chat 记录" + - 14: "tts 记录" + #语言 + - 63: "session_zh 中文" + - 62: "session_en 英文" + - 61: "session_ko 韩语" + #音色 + - 15: "girl" + - 16: "woman" + - 17: "man" + - 18: "leijun" + - 19: "dufu" + - 20: "hejiong" + - 21: "mahuateng" + - 22: "lidan" + - 23: "dabing" + - 24: "luoxiang" + - 25: "xuzhiyuan" + - 26: "yuhua" + - 27: "liuzhenyun" + 5. 音频文件保存目录: + - OUTPUT_PATH=/obscura/task/audio_files # 音频文件保存目录 + +## 注意事项 +- 注意redis的db分配 +- 注意kafka的topic分配 +- 注意producer的config.py环境配置 +- 注意producer_chat的.env环境配置 +- 确保模型权重文件已正确配置 +- 检查API密钥和环境变量设置 +- 注意资源使用和性能优化 \ No newline at end of file diff --git a/api/ollama_proxy.py b/api/ollama_proxy.py deleted file mode 100644 index 7b412a4..0000000 --- a/api/ollama_proxy.py +++ /dev/null @@ -1,97 +0,0 @@ -# 将本地Ollama API完全反向代理 -from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -import httpx -OLLAMA_URL = "http://127.0.0.1:11434" - -app = FastAPI() - -# 添加CORS中间件 -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # 允许所有来源 - allow_credentials=True, - allow_methods=["*"], # 允许所有HTTP方法 - allow_headers=["*"], # 允许所有HTTP头 -) - -# 创建异步HTTP客户端 -async_client = httpx.AsyncClient() - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) -async def proxy_to_ollama(request: Request, path: str): - if request.method == "OPTIONS": - # 处理预检请求 - headers = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "*", - } - return StreamingResponse(content=iter([]), headers=headers) - - target_url = f"{OLLAMA_URL}/{path}" - - # 获取请求体 - body = await request.body() - - # 获取请求头 - headers = dict(request.headers) - headers.pop("host", None) - - try: - # 将请求转换为Python请求 - python_request = { - "method": request.method, - "url": target_url, - "headers": headers, - "data": body - } - - # 使用Python请求发送到Ollama API - async with async_client.stream(**python_request) as response: - # 返回Ollama API的流式响应,并添加CORS头 - response_headers = dict(response.headers) - response_headers["Access-Control-Allow-Origin"] = "*" - return StreamingResponse( - response.aiter_raw(), - status_code=response.status_code, - headers=response_headers - ) - except httpx.RequestError as e: - return {"error": f"请求Ollama API时发生错误: {str(e)}"}, 500 - except httpx.StreamClosed: - # 处理流关闭异常 - print("流已关闭,客户端可能已断开连接") - return {"error": "流已关闭,客户端可能已断开连接"}, 499 - -@app.on_event("shutdown") -async def shutdown_event(): - await async_client.aclose() - -if __name__ == "__main__": - import uvicorn - import requests - import json - - # 测试Ollama API - test_url = "http://localhost:11434/api/generate" - test_data = { - "model": "llama3.1", - "prompt": "Why is the sky blue?", - "stream": False - } - - try: - response = requests.post(test_url, json=test_data) - if response.status_code == 200: - print("Ollama API 测试成功:") - print(json.dumps(response.json(), indent=2)) - else: - print(f"Ollama API 测试失败,状态码: {response.status_code}") - print(response.text) - except requests.RequestException as e: - print(f"Ollama API 测试出错: {str(e)}") - - # 启动FastAPI应用 - uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/api/producer/config.py b/api/producer/config.py new file mode 100644 index 0000000..049a8c2 --- /dev/null +++ b/api/producer/config.py @@ -0,0 +1,80 @@ +# config.py + +import os + +# Kafka配置 +KAFKA_BROKER = "222.186.10.253:9092" +KAFKA_GROUP_ID_PREFIX = "group" + +# Redis配置 +REDIS_HOST = "150.158.144.159" +REDIS_PORT = 13003 +REDIS_PASSWORD = "Obscura@2024" +MAIN_REDIS_DB = 0 +REDIS_API_DB = 2 +REDIS_API_USAGE_DB = 3 +# 目录配置 +UPLOAD_DIR = "/obscura/task/upload" +RESULT_DIR = "/obscura/task/result" + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 模型配置 +YOLO_MODEL_PATH = "/obscura/models/yolov8x.pt" +POSE_MODEL_PATH = "/obscura/models/yolov8x-pose.pt" +QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-2B-Instruct" +FALL_MODEL_PATH = "/obscura/models/yolov8n-fall.pt" +FACE_MODEL_PATH = "/obscura/models/yolov8n-face.pt" +MEDIAPIPE_MODEL_PATH = "/obscura/models/face_landmarker.task" +# Ollama配置 +OLLAMA_URL = "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate" + +# 各个worker的配置 +WORKER_CONFIGS = { + "yolo": { + "kafka_topic": "yolo", + "redis_db": 4, + }, + "pose": { + "kafka_topic": "pose", + "redis_db": 5, + }, + "qwenvl": { + "kafka_topic": "qwenvl", + "redis_db": 9, + }, + "qwenvl_analyze": { + "kafka_topic": "qwenvl_analyze", + "redis_db": 32, + }, + "cpm": { + "kafka_topic": "cpm", + "redis_db": 8, + }, + "cpm_analyze": { + "kafka_topic": "cpm_analyze", + "redis_db": 31, + }, + "fall": { + "kafka_topic": "fall", + "redis_db": 6, + }, + "face": { + "kafka_topic": "face", + "redis_db": 7, + }, + "mediapipe": { + "kafka_topic": "mediapipe", + "redis_db": 10, + }, + "compare": { + "kafka_topic": "compare", + "redis_db": 30, + } +} + +# GPU设置 +CUDA_DEVICE_0 = "cuda:0" +CUDA_DEVICE_1 = "cuda:1" diff --git a/api/producer/producer.py b/api/producer/producer.py new file mode 100644 index 0000000..e6c1067 --- /dev/null +++ b/api/producer/producer.py @@ -0,0 +1,442 @@ +# main.py +from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.security import APIKeyHeader +from kafka import KafkaProducer +from redis import Redis +import os +import json +import uuid +from datetime import datetime, timedelta, timezone +import string +from decord import VideoReader +from PIL import Image +from fastapi.responses import FileResponse +import logging +from config import * + +app = FastAPI() +v1_app = FastAPI() +app.mount("/v1", v1_app) + + +# CORS设置 +# ALLOWED_ORIGINS = ['https://beta.obscura.work'] + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +KAFKA_BROKER = KAFKA_BROKER +REDIS_HOST = REDIS_HOST +REDIS_PORT = REDIS_PORT +REDIS_PASSWORD = REDIS_PASSWORD +REDIS_DB = MAIN_REDIS_DB +REDIS_API_DB = REDIS_API_DB +REDIS_API_USAGE_DB = REDIS_API_USAGE_DB +UPLOAD_DIR = UPLOAD_DIR +RESULT_DIR = RESULT_DIR +MAX_FILE_AGE = timedelta(hours=1) + +# 确保目录存在 +os.makedirs(UPLOAD_DIR, exist_ok=True) +os.makedirs(RESULT_DIR, exist_ok=True) + +# 定义支持的任务类型 +KAFKA_TOPICS = { + 'pose': 'pose', + 'mediapipe': 'mediapipe', + 'qwenvl': 'qwenvl', + 'yolo': 'yolo', + 'fall': 'fall', + 'face': 'face', + 'cpm': 'cpm', + 'compare': 'compare', + 'qwenvl_analyze': 'qwenvl_analyze', + 'cpm_analyze': 'cpm_analyze' +} + +TASK_TYPES = list(KAFKA_TOPICS.keys()) + + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_DB +) + +redis_api_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_DB +) + +redis_api_usage_client = Redis( + host=REDIS_HOST, + port=REDIS_PORT, + password=REDIS_PASSWORD, + db=REDIS_API_USAGE_DB +) +redis_pose_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['pose']['redis_db']) +redis_cpm_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm']['redis_db']) +redis_yolo_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['yolo']['redis_db']) +redis_face_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['face']['redis_db']) +redis_fall_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['fall']['redis_db']) +redis_mediapipe_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['mediapipe']['redis_db']) +redis_qwenvl_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl']['redis_db']) +redis_compare_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['compare']['redis_db']) +redis_qwenvl_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['qwenvl_analyze']['redis_db']) +redis_cpm_analyze_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=WORKER_CONFIGS['cpm_analyze']['redis_db']) + +@v1_app.get('/favicon.ico', include_in_schema=False) +async def favicon(): + file_name = "favicon.ico" + file_path = os.path.join(app.root_path, "static", file_name) + if os.path.isfile(file_path): + return FileResponse(file_path) + else: + return {"message": "Favicon not found"}, 404 + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +# 定义 base62 字符集 +BASE62 = string.digits + string.ascii_letters + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + logging.info(f"验证API密钥: {api_key}") + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + logging.warning(f"API密钥不存在: {api_key}") + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + logging.warning(f"API密钥已停用: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + logging.warning(f"API密钥已过期: {api_key}") + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + logging.info(f"API密钥验证成功: {api_key}") + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + # 更新总的token使用量 + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + # 更新特定模型的token使用量 + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +def calculate_tokens(file_path: str, file_type: str) -> int: + base_tokens = 0 + + try: + file_size = os.path.getsize(file_path) # 获取文件大小(字节) + + # 基础token:每MB文件大小消耗10个token + base_tokens = int((file_size / (1024 * 1024)) * 0.1) + + if file_type == "image": + img = Image.open(file_path) + width, height = img.size + pixel_count = width * height + + # 图片token:每100个像素额外消耗5个token + image_tokens = int((pixel_count / 100000000) * 0.1) + + base_tokens += image_tokens + + elif file_type == "video": + vr = VideoReader(file_path) + fps = vr.get_avg_fps() + frame_count = len(vr) + width, height = vr[0].shape[1], vr[0].shape[0] + + pixel_count = width * height * frame_count + duration = frame_count / fps # 视频时长(秒) + + # 视频token:每100万像素每秒额外消耗1个token + video_tokens = int((pixel_count / 100000000) * (duration / 60)) + + base_tokens += video_tokens + + return max(1, base_tokens) # 确保至少返回1个token + except Exception as e: + print(f"计算token时出错: {str(e)}") + return 1 # 出错时返回默认值1 + + + +async def upload_file(file: UploadFile, task_type: str, api_key_info: dict): + if task_type not in KAFKA_TOPICS: + raise HTTPException(status_code=400, detail="不支持的任务类型") + + content = await file.read() + file_extension = os.path.splitext(file.filename)[1].lower() + new_filename = f"{uuid.uuid4()}{file_extension}" + + file_path = os.path.join(UPLOAD_DIR, new_filename) + with open(file_path, "wb") as f: + f.write(content) + + # 计算 token + file_type = "image" if file_extension in ['.jpg', '.jpeg', '.png'] else "video" + tokens_required = calculate_tokens(file_path, file_type) + + # 检查并更新 token 使用量 + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + await update_token_usage(api_key, tokens_required, task_type) + + # 创建任务记录 + task_id = str(uuid.uuid4()) + task_data = { + "task_id": task_id, + "filename": new_filename, + "file_type": file_type, + "task_type": task_type, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + # 存储任务信息到Redis + redis_client.set(f"task:{task_id}", json.dumps(task_data)) + logging.info(f"任务信息已存储到Redis: {task_id}") + + # 发送任务到对应的Kafka主题 + kafka_topic = KAFKA_TOPICS[task_type] + producer.send(kafka_topic, task_data) + logging.info(f"任务已发送到Kafka主题: {kafka_topic}") + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{task_type}_tokens_used", 0)) + + response_data = { + "message": "文件已上传并排队等待处理", + "task_id": task_id, + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{task_type}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + } + logging.info(f"上传文件完成: {task_id}") + return JSONResponse(content=response_data) + +# 为每个任务类型创建单独的端点 +@v1_app.post("/pose") +async def upload_pose(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + logging.info(f"收到 /pose端点的请求") + return await upload_file(file, task_type="pose", api_key_info=api_key_info) + +@v1_app.post("/cpm") +async def upload_cpm(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="cpm", api_key_info=api_key_info) + +@v1_app.post("/qwenvl") +async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="qwenvl", api_key_info=api_key_info) + +@v1_app.post("/qwenvl_analyze") +async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="qwenvl_analyze", api_key_info=api_key_info) + +@v1_app.post("/cpm_analyze") +async def upload_qwenvl(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="cpm_analyze", api_key_info=api_key_info) + + +@v1_app.post("/yolo") +async def upload_yolo(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="yolo", api_key_info=api_key_info) + +@v1_app.post("/fall") +async def upload_fall(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="fall", api_key_info=api_key_info) + +@v1_app.post("/face") +async def upload_face(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="face", api_key_info=api_key_info) + +@v1_app.post("/mediapipe") +async def upload_mediapipe(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="mediapipe", api_key_info=api_key_info) + +@v1_app.post("/compare") +async def upload_compare(file: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + return await upload_file(file, task_type="compare", api_key_info=api_key_info) + + +@v1_app.get("/result/{task_id}") +async def get_result(task_id: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + # 从 REDIS_DB (15) 获取任务状态 + task_info = redis_client.hgetall(f"task:{task_id}") + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()} + + if task_info['status'] != 'completed': + return {"status": task_info['status'], "message": "Task is not completed yet"} + + result_type = task_info['result_type'] + result_key = task_info['result_key'] + + # 根据任务类型选择相应的 Redis 客户端 + redis_client_map = { + 'pose': redis_pose_client, + 'cpm': redis_cpm_client, + 'yolo': redis_yolo_client, + 'face': redis_face_client, + 'fall': redis_fall_client, + 'mediapipe': redis_mediapipe_client, + 'qwenvl': redis_qwenvl_client, + 'compare': redis_compare_client, + 'qwenvl_analyze': redis_qwenvl_analyze_client, + 'cpm_analyze': redis_cpm_analyze_client + } + + result_redis = redis_client_map.get(result_type) + if not result_redis: + raise HTTPException(status_code=400, detail="Unsupported result type") + + result = result_redis.hgetall(result_key) + if not result: + raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found") + + result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} + + # 将 result 字段解析为 JSON(如果存在) + if 'result' in result: + result['result'] = json.loads(result['result']) + + return { + "status": "completed", + "result_type": result_type, + "result": result + } + +@v1_app.get("/annotated/{task_id}") +async def get_annotated_image(task_id: str, api_key: str = Depends(get_api_key)): + api_key_info = await verify_api_key(api_key) + if not api_key_info: + raise HTTPException(status_code=403, detail="无效的API密钥") + + # 从 REDIS_DB (15) 获取任务信息 + task_info = redis_client.hgetall(f"task:{task_id}") + if not task_info: + raise HTTPException(status_code=404, detail="Task not found") + + task_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in task_info.items()} + + if task_info['status'] != 'completed': + raise HTTPException(status_code=400, detail="Task is not completed yet") + + result_type = task_info.get('result_type') + result_key = task_info.get('result_key') + + if not result_key: + raise HTTPException(status_code=404, detail="Result key not found") + + if result_type in ['cpm', 'qwenvl','cpm_analyze', 'qwenvl_analyze']: + raise HTTPException(status_code=400, detail="Annotated image not available for this task type") + + # 根据任务类型选择相应的 Redis 客户端 + redis_client_map = { + 'pose': redis_pose_client, + 'yolo': redis_yolo_client, + 'face': redis_face_client, + 'fall': redis_fall_client, + 'mediapipe': redis_mediapipe_client, + 'compare': redis_compare_client + + } + + result_redis = redis_client_map.get(result_type) + if not result_redis: + raise HTTPException(status_code=400, detail="Unsupported result type") + + result = result_redis.hgetall(result_key) + if not result: + raise HTTPException(status_code=404, detail=f"{result_type.upper()} result not found") + + result = {k.decode('utf-8'): v.decode('utf-8') for k, v in result.items()} + + result_file = result.get('result_file') + if not result_file: + raise HTTPException(status_code=404, detail="Result file not found") + + file_path = os.path.join(RESULT_DIR, result_file) + if not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Result image file not found") + + return FileResponse(file_path, media_type="image/png") + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8005) \ No newline at end of file diff --git a/api/producer/requirements.txt b/api/producer/requirements.txt new file mode 100644 index 0000000..06e2a4c --- /dev/null +++ b/api/producer/requirements.txt @@ -0,0 +1,12 @@ +fastapi +uvicorn +python-multipart +kafka-python +redis +pillow +decord +pydantic +requests +python-jose[cryptography] +passlib[bcrypt] +sqlalchemy \ No newline at end of file diff --git a/api_chat/mp4_to_wav.py b/api_chat/before/mp4_to_wav.py similarity index 100% rename from api_chat/mp4_to_wav.py rename to api_chat/before/mp4_to_wav.py diff --git a/api_chat/ollamas.py b/api_chat/before/ollamas.py similarity index 100% rename from api_chat/ollamas.py rename to api_chat/before/ollamas.py diff --git a/api_chat/wav_to_text.py b/api_chat/before/wav_to_text.py similarity index 100% rename from api_chat/wav_to_text.py rename to api_chat/before/wav_to_text.py diff --git a/api_chat/weight.json b/api_chat/before/weight.json similarity index 100% rename from api_chat/weight.json rename to api_chat/before/weight.json diff --git a/api_chat/producer_chat/.env b/api_chat/producer_chat/.env new file mode 100644 index 0000000..b6516d9 --- /dev/null +++ b/api_chat/producer_chat/.env @@ -0,0 +1,94 @@ +# Kafka 配置 +KAFKA_BROKER=222.186.10.253:9092 +KAFKA_ASR_TOPIC=asr +KAFKA_CHAT_TOPIC=chat +KAFKA_TTS_TOPIC=tts + + +# Redis 配置 +REDIS_HOST=150.158.144.159 +REDIS_PORT=13003 +REDIS_ASR_DB=12 +REDIS_CHAT_DB=13 +REDIS_TTS_DB=14 +REDIS_PASSWORD=Obscura@2024 +REDIS_API_DB=2 +REDIS_API_USAGE_DB=3 +REDIS_TASK_DB=11 +REDIS_SESSION_DB=63 + +REDIS_SESSION_DB_ZH=63 +REDIS_SESSION_DB_EN=62 +REDIS_SESSION_DB_KO=61 + +# CORS 配置 +# ALLOWED_ORIGINS=https://beta.obscura.work + + +# GPT-SoVITS 配置 +GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt +SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth +REF_AUDIO_PATH=sample/woman.wav +REF_TEXT_PATH=sample/woman.txt +REF_LANGUAGE=中文 +TARGET_LANGUAGE=多语种混合 +OUTPUT_PATH=/obscura/task/audio_files + +# VOICE_CONFIGS +GIRL_REF_AUDIO=sample/gril.wav +GIRL_REF_TEXT=sample/gril.txt + +WOMAN_REF_AUDIO=sample/woman.wav +WOMAN_REF_TEXT=sample/woman.txt + + +MAN_REF_AUDIO=sample/man.wav +MAN_REF_TEXT=sample/man.txt + +LEIJUN_REF_AUDIO=sample/leijun.wav +LEIJUN_REF_TEXT=sample/leijun.txt + +DUFU_REF_AUDIO=sample/dufu.wav +DUFU_REF_TEXT=sample/dufu.txt + +HEJIONG_REF_AUDIO=sample/hejiong.wav +HEJIONG_REF_TEXT=sample/hejiong.txt + +MAHUATENG_REF_AUDIO=sample/mahuateng.wav +MAHUATENG_REF_TEXT=sample/mahuateng.txt + +LIDAN_REF_AUDIO=sample/lidan.wav +LIDAN_REF_TEXT=sample/lidan.txt + +YUHUA_REF_AUDIO=sample/yuhua.wav +YUHUA_REF_TEXT=sample/yuhua.txt + +LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav +LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt + +DABING_REF_AUDIO=sample/dabing.wav +DABING_REF_TEXT=sample/dabing.txt + +LUOXIANG_REF_AUDIO=sample/luoxiang.wav +LUOXIANG_REF_TEXT=sample/luoxiang.txt + +XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav +XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt + + +REDIS_GIRL_DB = 15 +REDIS_WOMAN_DB = 16 +REDIS_MAN_DB = 17 +REDIS_LEIJUN_DB = 18 +REDIS_DUFU_DB = 19 +REDIS_HEJIONG_DB = 20 +REDIS_MAHUATENG_DB = 21 +REDIS_LIDAN_DB = 22 +REDIS_DABING_DB = 23 +REDIS_LUOXIANG_DB = 24 +REDIS_XUZHIYUAN_DB = 25 +REDIS_YUHUA_DB = 26 +REDIS_LIUZHENYUN_DB = 27 + + + diff --git a/api_chat/producer_chat/producer_chat.py b/api_chat/producer_chat/producer_chat.py new file mode 100644 index 0000000..75a70d2 --- /dev/null +++ b/api_chat/producer_chat/producer_chat.py @@ -0,0 +1,406 @@ +from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader +from fastapi.responses import FileResponse +from pydantic import BaseModel +from kafka import KafkaProducer +from redis import Redis +import os +import json +import uuid +from datetime import datetime, timezone +from dotenv import load_dotenv +import tempfile +import hashlib +from pydantic import BaseModel, Field + +# 在文件顶部添加这个函数 +def get_audio_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +# 加载 .env 文件 +load_dotenv() + +app = FastAPI() +v1_chat_app = FastAPI() +app.mount("/v1_chat", v1_chat_app) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 配置 +KAFKA_BROKER = os.getenv('KAFKA_BROKER') +REDIS_HOST = os.getenv('REDIS_HOST') +REDIS_PORT = int(os.getenv('REDIS_PORT')) +REDIS_PASSWORD = os.getenv('REDIS_PASSWORD') +REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) +REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB')) +REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB')) +REDIS_API_DB = int(os.getenv('REDIS_API_DB')) +REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB')) +REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) + + +# Redis 配置 +REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB')) +REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB')) +REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB')) +REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB')) +REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB')) +REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB')) +REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB')) +REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB')) +REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB')) +REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB')) +REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB')) +REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB')) +REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB')) + +KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC') +KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC') +KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC') + +OUTPUT_PATH= os.getenv('OUTPUT_PATH') + +# 初始化 Kafka Producer +producer = KafkaProducer( + bootstrap_servers=[KAFKA_BROKER], + value_serializer=lambda v: json.dumps(v).encode('utf-8') +) + +# 初始化 Redis +redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB) +redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB) +redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB) +redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB) +redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB) +redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB) + +redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB) +redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB) +redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB) +redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB) +redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB) +redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB) +redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB) +redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB) +redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB) +redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB) +redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB) +redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB) +redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB) + +# 创建一个音色到对应 Redis 客户端的映射 +voice_to_redis = { + "default": redis_tts_girl, + "girl": redis_tts_girl, + "woman": redis_tts_woman, + "man": redis_tts_man, + "leijun": redis_tts_leijun, + "dufu": redis_tts_dufu, + "hejiong": redis_tts_hejiong, + "mahuateng": redis_tts_mahuateng, + "lidan": redis_tts_lidan, + "yuhua": redis_tts_yuhua, + "liuzhenyun": redis_tts_liuzhenyun, + "dabing": redis_tts_dabing, + "luoxiang": redis_tts_luoxiang, + "xuzhiyuan": redis_tts_xuzhiyuan +} + + +# 定义API密钥头部 +api_key_header = APIKeyHeader(name="Authorization", auto_error=False) + +def get_audio_hash(text): + return hashlib.md5(text.encode()).hexdigest() + +# 验证API密钥的函数 +async def get_api_key(api_key: str = Security(api_key_header)): + if api_key and api_key.startswith("Bearer "): + key = api_key.split(" ")[1] + if key.startswith("obs-"): + return key + raise HTTPException( + status_code=401, + detail="无效的API密钥", + headers={"WWW-Authenticate": "Bearer"}, + ) + +async def verify_api_key(api_key: str = Depends(get_api_key)): + redis_key = f"api_key:{api_key}" + + api_key_info = redis_api_client.hgetall(redis_key) + + if not api_key_info: + raise HTTPException(status_code=401, detail="无效的API密钥") + + api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()} + + if api_key_info.get('is_active') != '1': + raise HTTPException(status_code=401, detail="API密钥已停用") + + expires_at = datetime.fromisoformat(api_key_info.get('expires_at')) + if datetime.now(timezone.utc) > expires_at: + raise HTTPException(status_code=401, detail="API密钥已过期") + + usage_info = redis_api_usage_client.hgetall(redis_key) + usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()} + + return { + "api_key": api_key, + **api_key_info, + **usage_info + } + +async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str): + redis_key = f"api_key:{api_key}" + current_time = datetime.now(timezone.utc).isoformat() + + pipe = redis_api_usage_client.pipeline() + + pipe.hincrby(redis_key, "tokens_used", new_tokens_used) + pipe.hset(redis_key, "last_used_at", current_time) + + model_tokens_field = f"{model_name}_tokens_used" + model_last_used_field = f"{model_name}_last_used_at" + + pipe.hincrby(redis_key, model_tokens_field, new_tokens_used) + pipe.hset(redis_key, model_last_used_field, current_time) + + pipe.execute() + +async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str): + api_key = api_key_info['api_key'] + usage_key = f"api_key:{api_key}" + total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0) + tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0) + + if tokens_used + tokens_required > total_tokens: + raise HTTPException(status_code=403, detail="Token 余额不足") + + # 更新 token 使用量 + await update_token_usage(api_key, tokens_required, model_name) + + # 发送任务到Kafka + producer.send(kafka_topic, task_data) + + # 获取更新后的 token 使用情况 + updated_api_key_info = await verify_api_key(api_key) + new_tokens_used = int(updated_api_key_info.get("tokens_used", 0)) + model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0)) + + return { + "message": f"{model_name.upper()}请求已排队等待处理", + "tokens_used": tokens_required, + "total_tokens_used": new_tokens_used, + f"{model_name}_tokens_used": model_tokens_used, + "tokens_remaining": total_tokens - new_tokens_used + } + +class TTSRequest(BaseModel): + text: str + voice: str = Field(..., description="选择的音色") + +class ChatRequest(BaseModel): + session_id: str + query: str + model: str = "qwen2.5:3b" + + +@v1_chat_app.post("/tts") +async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + text_hash = get_audio_hash(request.text) + + # 验证音色选择 + valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"] + if request.voice not in valid_voices: + raise HTTPException(status_code=400, detail="无效的音色选择") + + # 如果声音是 'default',则将其视为 'girl' + voice = 'girl' if request.voice == 'default' else request.voice + + # 使用对应音色的 Redis 客户端 + redis_tts = voice_to_redis[request.voice] + + # 检查是否已存在相同内容的音频文件 + existing_audio_info = redis_tts.get(f"tts:{text_hash}") + if existing_audio_info: + existing_audio_path = json.loads(existing_audio_info)['path'] + if os.path.exists(existing_audio_path): + return { + "message": "TTS请求已完成", + "task_id": task_id, + "status": "completed", + "audio_path": existing_audio_path + } + + # 如果不存在,创建新的任务 + task_data = { + "task_id": task_id, + "text": request.text, + "text_hash": text_hash, + "voice": request.voice, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + # 存储任务信息到Redis + redis_task_client.set(f"task_status:tts:{task_id}", "queued") + + result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC) + result["task_id"] = task_id + return result +@v1_chat_app.post("/asr") +async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + + UPLOAD_DIR = "/obscura/task/audio_upload" + os.makedirs(UPLOAD_DIR, exist_ok=True) + file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav") + + with open(file_path, "wb") as temp_audio: + content = await audio.read() + temp_audio.write(content) + + task_data = { + 'file_path': file_path, + 'task_id': task_id, + 'status': 'queued' + } + + # 存储任务状态,使用一致的键名格式 + redis_task_client.set(f"task_status:asr:{task_id}", "queued") + + result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC) + result["task_id"] = task_id + return result + +@v1_chat_app.post("/chat") +async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)): + task_id = str(uuid.uuid4()) + task_data = { + "task_id": task_id, + "session_id": request.session_id, + "query": request.query, + "model": request.model, + "status": "queued", + "created_at": datetime.now(timezone.utc).isoformat(), + } + + # 设置任务状态为 "queued" + redis_task_client.set(f"chat:{task_id}:status", "queued") + + result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC) + result["task_id"] = task_id + return result + + +@v1_chat_app.get("/chat_result/{task_id}") +async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + # 从Redis任务数据库获取任务状态 + task_status = redis_task_client.get(f"chat:{task_id}:status") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从Redis任务数据库获取聊天结果 + chat_result = redis_task_client.get(f"chat:{task_id}:result") + if chat_result: + result = json.loads(chat_result) + return { + "status": "completed", + "result": result + } + return {"status": status} + + return {"status": "not_found"} + +@v1_chat_app.get("/tts_result/{task_id}") +async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + task_status = redis_task_client.get(f"task_status:tts:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + task_info = redis_task_client.get(f"task_info:tts:{task_id}") + if task_info: + task_data = json.loads(task_info) + text_hash = task_data['text_hash'] + voice = task_data['voice'] + # 'default' 和 'girl' 都使用 girl 的 Redis + redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice] + + audio_info = redis_tts.get(f"tts:{text_hash}") + if audio_info: + audio_path = json.loads(audio_info)['path'] + return { + "status": "completed", + "audio_path": audio_path + } + return {"status": status} + + return {"status": "not_found"} + +@v1_chat_app.get("/asr_result/{task_id}") +async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)): + # 从Redis任务数据库获取任务状态,使用一致的键名格式 + task_status = redis_task_client.get(f"task_status:asr:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从Redis ASR结果数据库获取转录结果 + transcription = redis_asr_client.get(f"asr:{task_id}") + if transcription: + return { + "status": "completed", + "transcription": transcription.decode('utf-8') + } + return {"status": status} + + return {"status": "not_found"} + +@v1_chat_app.get("/tts_audio/{task_id}") +async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)): + task_status = redis_task_client.get(f"task_status:tts:{task_id}") + if task_status: + status = task_status.decode('utf-8') + if status == "completed": + # 从任务信息中获取使用的音色 + task_info = redis_task_client.get(f"task_info:tts:{task_id}") + if task_info: + task_data = json.loads(task_info) + voice = task_data.get('voice', 'girl') # 默认使用 'girl' + # 'default' 和 'girl' 都使用 girl 的 Redis + redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice] + + # 从对应音色的 Redis 获取音频文件路径 + audio_info = redis_tts.get(f"tts:{task_data['text_hash']}") + if audio_info: + audio_path = json.loads(audio_info)['path'] + if os.path.exists(audio_path): + file_name = os.path.basename(audio_path) + return FileResponse(audio_path, media_type="audio/wav", filename=file_name) + else: + raise HTTPException(status_code=404, detail="音频文件不存在") + elif status == "queued" or status == "processing": + raise HTTPException(status_code=202, detail="音频文件正在生成中") + else: + raise HTTPException(status_code=500, detail="任务处理失败") + + raise HTTPException(status_code=404, detail="任务不存在") +@v1_chat_app.get("/getvoice") +async def get_available_voices(api_key_info: dict = Depends(verify_api_key)): + valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"] + return {"available_voices": valid_voices} + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8008) + + diff --git a/api_chat/producer_chat/requirements.txt b/api_chat/producer_chat/requirements.txt new file mode 100644 index 0000000..a0f9b40 --- /dev/null +++ b/api_chat/producer_chat/requirements.txt @@ -0,0 +1,7 @@ +fastapi +uvicorn +pydantic +kafka-python +redis +python-dotenv +python-multipart \ No newline at end of file diff --git a/api/main.py b/api_history/main.py similarity index 100% rename from api/main.py rename to api_history/main.py diff --git a/api/producer.py b/api_old/producer.py similarity index 100% rename from api/producer.py rename to api_old/producer.py