Files
api/producer_chat/producer_chat.py
2025-04-10 09:45:41 +00:00

404 lines
16 KiB
Python
Executable File

from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from fastapi.responses import FileResponse
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import hashlib
from pydantic import BaseModel, Field
def get_audio_hash(text):
    """Return the hex MD5 digest of ``text`` encoded as UTF-8.

    The digest is the cache key for synthesized audio (``tts:<hash>`` in a
    voice's Redis database), so identical text resolves to the same entry.
    """
    digest = hashlib.md5(text.encode("utf-8"))
    return digest.hexdigest()
# Populate the process environment from a .env file before the configuration
# below is read with os.getenv().
load_dotenv()

# Root ASGI application; the versioned chat API is a separate FastAPI
# sub-application served under the /v1_chat prefix.
app = FastAPI()
v1_chat_app = FastAPI()

# CORS is configured on the root app: every origin, method and header is
# allowed, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.mount("/v1_chat", v1_chat_app)
# Configuration read from the process environment. String settings are read
# as-is and are None when unset; integer settings are required.
def _require_int_env(name):
    """Return environment variable ``name`` parsed as an int.

    Raises:
        RuntimeError: if the variable is unset or not an integer. The message
            names the variable, unlike the bare ``int(None)`` TypeError a
            missing setting would otherwise produce at import time.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _require_int_env('REDIS_PORT')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Redis logical database numbers for the core services.
REDIS_TTS_DB = _require_int_env('REDIS_TTS_DB')
REDIS_ASR_DB = _require_int_env('REDIS_ASR_DB')
REDIS_CHAT_DB = _require_int_env('REDIS_CHAT_DB')
REDIS_API_DB = _require_int_env('REDIS_API_DB')
REDIS_API_USAGE_DB = _require_int_env('REDIS_API_USAGE_DB')
REDIS_TASK_DB = _require_int_env('REDIS_TASK_DB')
# Redis logical database numbers for the per-voice TTS audio caches.
REDIS_GIRL_DB = _require_int_env('REDIS_GIRL_DB')
REDIS_WOMAN_DB = _require_int_env('REDIS_WOMAN_DB')
REDIS_MAN_DB = _require_int_env('REDIS_MAN_DB')
REDIS_LEIJUN_DB = _require_int_env('REDIS_LEIJUN_DB')
REDIS_DUFU_DB = _require_int_env('REDIS_DUFU_DB')
REDIS_HEJIONG_DB = _require_int_env('REDIS_HEJIONG_DB')
REDIS_MAHUATENG_DB = _require_int_env('REDIS_MAHUATENG_DB')
REDIS_LIDAN_DB = _require_int_env('REDIS_LIDAN_DB')
REDIS_YUHUA_DB = _require_int_env('REDIS_YUHUA_DB')
REDIS_LIUZHENYUN_DB = _require_int_env('REDIS_LIUZHENYUN_DB')
REDIS_DABING_DB = _require_int_env('REDIS_DABING_DB')
REDIS_LUOXIANG_DB = _require_int_env('REDIS_LUOXIANG_DB')
REDIS_XUZHIYUAN_DB = _require_int_env('REDIS_XUZHIYUAN_DB')
# Kafka topics the TTS / ASR / chat tasks are published to.
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
# Directory where uploaded ASR audio files are written.
UPLOAD_DIR = os.getenv('UPLOAD_DIR')
# Shared Kafka producer for all task topics; each message value is
# serialized to JSON and encoded as UTF-8 bytes.
producer = KafkaProducer(
    value_serializer=lambda payload: json.dumps(payload).encode('utf-8'),
    bootstrap_servers=[KAFKA_BROKER],
)
# Redis clients: one per logical database, all on the same host/port/password.
def _redis_client(db):
    """Build a Redis client for logical database ``db`` on the shared server."""
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


# Core service databases.
redis_tts_client = _redis_client(REDIS_TTS_DB)
redis_asr_client = _redis_client(REDIS_ASR_DB)
redis_chat_client = _redis_client(REDIS_CHAT_DB)
redis_api_client = _redis_client(REDIS_API_DB)
redis_api_usage_client = _redis_client(REDIS_API_USAGE_DB)
redis_task_client = _redis_client(REDIS_TASK_DB)
# Per-voice TTS audio cache databases.
redis_tts_girl = _redis_client(REDIS_GIRL_DB)
redis_tts_woman = _redis_client(REDIS_WOMAN_DB)
redis_tts_man = _redis_client(REDIS_MAN_DB)
redis_tts_leijun = _redis_client(REDIS_LEIJUN_DB)
redis_tts_dufu = _redis_client(REDIS_DUFU_DB)
redis_tts_hejiong = _redis_client(REDIS_HEJIONG_DB)
redis_tts_mahuateng = _redis_client(REDIS_MAHUATENG_DB)
redis_tts_lidan = _redis_client(REDIS_LIDAN_DB)
redis_tts_yuhua = _redis_client(REDIS_YUHUA_DB)
redis_tts_liuzhenyun = _redis_client(REDIS_LIUZHENYUN_DB)
redis_tts_dabing = _redis_client(REDIS_DABING_DB)
redis_tts_luoxiang = _redis_client(REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = _redis_client(REDIS_XUZHIYUAN_DB)
# Voice name -> Redis client holding that voice's cached audio. The keys are
# the voice names the API accepts; "default" is an alias sharing the girl cache.
voice_to_redis = dict(
    default=redis_tts_girl,
    girl=redis_tts_girl,
    woman=redis_tts_woman,
    man=redis_tts_man,
    leijun=redis_tts_leijun,
    dufu=redis_tts_dufu,
    hejiong=redis_tts_hejiong,
    mahuateng=redis_tts_mahuateng,
    lidan=redis_tts_lidan,
    yuhua=redis_tts_yuhua,
    liuzhenyun=redis_tts_liuzhenyun,
    dabing=redis_tts_dabing,
    luoxiang=redis_tts_luoxiang,
    xuzhiyuan=redis_tts_xuzhiyuan,
)
# API keys arrive in the Authorization header. auto_error=False makes a missing
# header resolve to None, so get_api_key can answer with its own 401.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the API key from an ``Authorization: Bearer <key>`` header.

    Returns:
        The token after "Bearer " (up to the next space) when it starts with
        "obs-".

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the header is
            missing, is not a Bearer credential, or the token lacks "obs-".
    """
    is_bearer = bool(api_key) and api_key.startswith("Bearer ")
    token = api_key.split(" ")[1] if is_bearer else None
    if token is not None and token.startswith("obs-"):
        return token
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate an API key against Redis and return its metadata plus usage.

    Args:
        api_key: Bare key (no "Bearer " prefix). As a FastAPI dependency it is
            supplied by get_api_key; it may also be called directly with a key.

    Returns:
        Dict with "api_key" plus every field of ``api_key:<key>`` in
        redis_api_client and of the same hash in redis_api_usage_client, all
        decoded as UTF-8 strings (usage fields win on name clashes).

    Raises:
        HTTPException: 401 when the key is unknown, inactive (is_active != '1'),
            has a missing or unparseable expires_at, or is past expires_at.
    """
    redis_key = f"api_key:{api_key}"
    api_key_info = redis_api_client.hgetall(redis_key)
    if not api_key_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    # A missing or malformed expiry is rejected as an invalid key instead of
    # crashing inside fromisoformat() with an HTTP 500.
    raw_expiry = api_key_info.get('expires_at')
    if not raw_expiry:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    if raw_expiry.endswith('Z'):
        # fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
        raw_expiry = raw_expiry[:-1] + '+00:00'
    try:
        expires_at = datetime.fromisoformat(raw_expiry)
    except ValueError:
        raise HTTPException(status_code=401, detail="无效的API密钥") from None
    if expires_at.tzinfo is None:
        # Comparing naive and aware datetimes raises TypeError.
        # NOTE(review): naive timestamps are assumed to be UTC (this module
        # writes its own timestamps in UTC) — confirm with the key issuer.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info
    }
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add ``new_tokens_used`` to a key's aggregate and per-model usage counters.

    Updates ``tokens_used``/``last_used_at`` and
    ``<model_name>_tokens_used``/``<model_name>_last_used_at`` in the key's
    usage hash through a single Redis pipeline, using one UTC timestamp.
    """
    usage_key = f"api_key:{api_key}"
    stamp = datetime.now(timezone.utc).isoformat()
    pipe = redis_api_usage_client.pipeline()
    # Prefix "" targets the aggregate fields, "<model>_" the per-model ones.
    for field_prefix in ("", f"{model_name}_"):
        pipe.hincrby(usage_key, f"{field_prefix}tokens_used", new_tokens_used)
        pipe.hset(usage_key, f"{field_prefix}last_used_at", stamp)
    pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge a key for one task, publish the task to Kafka, and report usage.

    Args:
        api_key_info: Dict from verify_api_key; only its "api_key" entry is read.
        model_name: Service label ("tts", "asr" or "chat" in this module); it
            names the per-model usage fields and appears in the message.
        tokens_required: Tokens charged for this request.
        task_data: JSON-serializable payload published to ``kafka_topic``.
        kafka_topic: Kafka topic the task is sent to.

    Returns:
        Dict with a queued message, the tokens charged now, the key's total and
        per-model usage after the charge, and the remaining balance.

    Raises:
        HTTPException: 403 when ``tokens_used + tokens_required`` would exceed
            ``total_tokens``; in that case nothing is charged or enqueued.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    # NOTE(review): the balance check and the charge are separate Redis calls,
    # so concurrent requests served by different workers can overdraw slightly.
    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    await update_token_usage(api_key, tokens_required, model_name)
    producer.send(kafka_topic, task_data)

    # Read the post-charge counters straight from the usage hash. Re-validating
    # the key at this point is unnecessary: the task is already charged and
    # queued, so a 401 raised now would only hide a successful submission.
    model_tokens_field = f"{model_name}_tokens_used"
    raw_tokens_used, raw_model_tokens_used = redis_api_usage_client.hmget(
        usage_key, ["tokens_used", model_tokens_field]
    )
    new_tokens_used = int(raw_tokens_used or 0)
    model_tokens_used = int(raw_model_tokens_used or 0)
    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
class TTSRequest(BaseModel):
    # Request body for POST /tts.
    text: str  # text to synthesize; its MD5 digest is the audio cache key
    voice: str = Field(..., description="选择的音色")  # required; validated against the supported voices by the endpoint
class ChatRequest(BaseModel):
    # Request body for POST /chat; all fields are forwarded in the Kafka task.
    session_id: str  # caller-supplied conversation/session identifier
    query: str  # user message for this turn
    model: str = "qwen2.5:3b"  # model name passed through to the chat worker
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a text-to-speech task, or answer immediately from the audio cache.

    The cache key is ``tts:<md5 of text>`` in the Redis database of the chosen
    voice ("default" shares the "girl" database). On a cache hit whose file
    still exists on disk, no tokens are charged and the response carries
    ``audio_path``. The task is also recorded as completed in the task
    database, so /tts_result and /tts_audio can resolve the returned task_id.
    Otherwise the task is marked "queued" and handed to process_request.

    Raises:
        HTTPException: 400 for an unknown voice; anything process_request
            raises (e.g. 403 on insufficient tokens).
    """
    task_id = str(uuid.uuid4())
    text_hash = get_audio_hash(request.text)

    # voice_to_redis holds exactly the supported voice names (the list /getvoice
    # advertises), so validate against it rather than a separate copy.
    if request.voice not in voice_to_redis:
        raise HTTPException(status_code=400, detail="无效的音色选择")
    # 'default' maps to the same client as 'girl' in voice_to_redis.
    redis_tts = voice_to_redis[request.voice]

    task_data = {
        "task_id": task_id,
        "text": request.text,
        "text_hash": text_hash,
        # Forward the voice exactly as requested (including 'default').
        "voice": request.voice,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Reuse previously synthesized audio for identical text and voice.
    existing_audio_info = redis_tts.get(f"tts:{text_hash}")
    if existing_audio_info:
        existing_audio_path = json.loads(existing_audio_info)['path']
        if os.path.exists(existing_audio_path):
            # Record the task so the result/audio endpoints, which look up
            # task_status and task_info (text_hash, voice), can find it.
            task_data["status"] = "completed"
            redis_task_client.set(f"task_info:tts:{task_id}", json.dumps(task_data))
            redis_task_client.set(f"task_status:tts:{task_id}", "completed")
            return {
                "message": "TTS请求已完成",
                "task_id": task_id,
                "status": "completed",
                "audio_path": existing_audio_path
            }

    status_key = f"task_status:tts:{task_id}"
    redis_task_client.set(status_key, "queued")
    try:
        result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
    except HTTPException as exc:
        # process_request raises 403 before charging or enqueuing anything and
        # the task_id is never returned, so drop the orphaned "queued" status.
        if exc.status_code == 403:
            redis_task_client.delete(status_key)
        raise
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Save an uploaded audio file and queue a speech-recognition task for it.

    The upload is written to ``UPLOAD_DIR/<task_id>.wav``. The .wav name is
    used regardless of the uploaded format. The file path is then published
    to the ASR Kafka topic.

    Returns:
        The process_request response plus the new ``task_id``.

    Raises:
        HTTPException: anything process_request raises. On a 403 (insufficient
            tokens) the saved file and the "queued" status key are removed
            first, so a rejected request leaves nothing behind.
    """
    task_id = str(uuid.uuid4())
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
    # Copy the upload in 1 MiB chunks instead of buffering it whole in memory.
    with open(file_path, "wb") as temp_audio:
        while True:
            chunk = await audio.read(1024 * 1024)
            if not chunk:
                break
            temp_audio.write(chunk)
    task_data = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }
    # Status key format shared with /asr_result.
    status_key = f"task_status:asr:{task_id}"
    redis_task_client.set(status_key, "queued")
    try:
        result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
    except HTTPException as exc:
        # process_request raises 403 before charging or enqueuing anything, so
        # nobody will ever consume this upload or status key.
        if exc.status_code == 403:
            redis_task_client.delete(status_key)
            try:
                os.remove(file_path)
            except OSError:
                pass  # best-effort cleanup; the 403 is what the client needs
        raise
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a chat task and return its task_id for polling via /chat_result.

    Raises:
        HTTPException: anything process_request raises. On a 403 (insufficient
            tokens) the "queued" status key is removed first.
    """
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "session_id": request.session_id,
        "query": request.query,
        "model": request.model,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    # Chat tasks use the "chat:<id>:status" key format read by /chat_result.
    status_key = f"chat:{task_id}:status"
    redis_task_client.set(status_key, "queued")
    try:
        result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
    except HTTPException as exc:
        # process_request raises 403 before charging or enqueuing anything and
        # the task_id is never returned, so drop the orphaned "queued" status.
        if exc.status_code == 403:
            redis_task_client.delete(status_key)
        raise
    result["task_id"] = task_id
    return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Report a chat task's status, including its result once completed.

    Returns ``{"status": "not_found"}`` for unknown ids,
    ``{"status": "completed", "result": ...}`` when a stored result exists, and
    otherwise only the stored status string.
    """
    raw_status = redis_task_client.get(f"chat:{task_id}:status")
    if not raw_status:
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    if status != "completed":
        return {"status": status}
    raw_result = redis_task_client.get(f"chat:{task_id}:result")
    if not raw_result:
        return {"status": status}
    return {"status": "completed", "result": json.loads(raw_result)}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Report a TTS task's status, with the audio path once it is available.

    For a completed task, the audio path is resolved through the task's
    stored info (text_hash and voice) and that voice's audio cache.
    """
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not raw_status:
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    if status == "completed":
        raw_info = redis_task_client.get(f"task_info:tts:{task_id}")
        if raw_info:
            info = json.loads(raw_info)
            text_hash = info['text_hash']
            voice = info['voice']
            # 'default' and 'girl' both read the girl cache database.
            cache_voice = 'girl' if voice in ('default', 'girl') else voice
            audio_info = voice_to_redis[cache_voice].get(f"tts:{text_hash}")
            if audio_info:
                return {
                    "status": "completed",
                    "audio_path": json.loads(audio_info)['path']
                }
    return {"status": status}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Report an ASR task's status, with the transcription once available.

    The status is read from the task database (``task_status:asr:<id>``). The
    transcription is read from the ASR database (``asr:<id>``) only when the
    task is completed.
    """
    raw_status = redis_task_client.get(f"task_status:asr:{task_id}")
    if not raw_status:
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    transcription = redis_asr_client.get(f"asr:{task_id}") if status == "completed" else None
    if transcription:
        return {
            "status": "completed",
            "transcription": transcription.decode('utf-8')
        }
    return {"status": status}
@v1_chat_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Return the generated WAV file for a completed TTS task.

    Raises:
        HTTPException:
            404 "任务不存在" when no status is recorded for ``task_id``;
            202 "音频文件正在生成中" while the task is queued or processing;
            500 "任务处理失败" for any other non-completed status;
            404 "音频文件不存在" when the task is completed but its audio
            cannot be located (no task info, no cache entry, or no file).
    """
    task_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not task_status:
        raise HTTPException(status_code=404, detail="任务不存在")

    status = task_status.decode('utf-8')
    if status in ("queued", "processing"):
        # NOTE(review): a 202 raised via HTTPException still has a
        # {"detail": ...} body; clients should branch on the status code.
        raise HTTPException(status_code=202, detail="音频文件正在生成中")
    if status != "completed":
        raise HTTPException(status_code=500, detail="任务处理失败")

    # The task exists and is completed, so any lookup failure from here on is
    # reported as a missing audio file, not as a missing task.
    task_info = redis_task_client.get(f"task_info:tts:{task_id}")
    if not task_info:
        raise HTTPException(status_code=404, detail="音频文件不存在")
    task_data = json.loads(task_info)
    voice = task_data.get('voice', 'girl')  # fall back to 'girl' when no voice was recorded
    # 'default' and 'girl' both read the girl cache database.
    redis_tts = voice_to_redis['girl'] if voice in ('default', 'girl') else voice_to_redis[voice]
    audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
    if not audio_info:
        raise HTTPException(status_code=404, detail="音频文件不存在")
    audio_path = json.loads(audio_info)['path']
    # isfile rather than exists: a directory path must not be handed to FileResponse.
    if not os.path.isfile(audio_path):
        raise HTTPException(status_code=404, detail="音频文件不存在")
    return FileResponse(audio_path, media_type="audio/wav", filename=os.path.basename(audio_path))
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
    """List the voice names accepted by POST /tts.

    The list is derived from voice_to_redis, the voice -> cache mapping, rather
    than a hand-maintained copy that could drift. Dict insertion order gives
    the same ordering: "default" first, then the concrete voices.
    """
    return {"available_voices": list(voice_to_redis)}
if __name__ == "__main__":
    # Development entry point: serve the root app, including the mounted
    # /v1_chat API, on all interfaces at port 8008.
    import uvicorn

    uvicorn.run(app, port=8008, host="0.0.0.0")