# Listing metadata from the original paste: 319 lines, 11 KiB, Python.
import asyncio
import contextlib
import hashlib
import json
import os
import tempfile
import uuid
from datetime import datetime, timezone

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from kafka import KafkaProducer
from pydantic import BaseModel
from redis import Redis

# Load variables from a local .env file into the process environment before
# any of the configuration below is read.
load_dotenv()

# Root ASGI application. The versioned API lives in a sub-application mounted
# under /v1_chat.
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)

# CORS policy: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended for production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Configuration, read from the process environment (including any values that
# load_dotenv() loaded from .env).


def _require_int_env(name):
    """Return environment variable *name* parsed as an int.

    A bare ``int(os.getenv(name))`` fails at startup with an opaque
    ``TypeError`` (``int(None)``) when the variable is unset. This helper names
    the offending variable instead.

    Raises:
        RuntimeError: if the variable is unset or is not a valid integer.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


KAFKA_BROKER = os.getenv('KAFKA_BROKER')

REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _require_int_env('REDIS_PORT')
# May legitimately be unset (None) for a password-less Redis.
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = _require_int_env('REDIS_TTS_DB')
REDIS_ASR_DB = _require_int_env('REDIS_ASR_DB')
REDIS_CHAT_DB = _require_int_env('REDIS_CHAT_DB')
REDIS_API_DB = _require_int_env('REDIS_API_DB')
REDIS_API_USAGE_DB = _require_int_env('REDIS_API_USAGE_DB')
REDIS_TASK_DB = _require_int_env('REDIS_TASK_DB')

KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')

def _encode_json(value):
    """Serialize a task payload to UTF-8 encoded JSON bytes for Kafka."""
    return json.dumps(value).encode("utf-8")


# Shared Kafka producer; process_request uses it to enqueue TTS / ASR / chat
# tasks on their topics.
producer = KafkaProducer(
    value_serializer=_encode_json,
    bootstrap_servers=[KAFKA_BROKER],
)

def _redis_for_db(db_index):
    """Build a Redis client for logical database *db_index* on the configured server."""
    return Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        password=REDIS_PASSWORD,
        db=db_index,
    )


# One client per logical database, all on the same server.
redis_tts_client = _redis_for_db(REDIS_TTS_DB)              # TTS results (JSON with "path")
redis_asr_client = _redis_for_db(REDIS_ASR_DB)              # ASR transcriptions
redis_chat_client = _redis_for_db(REDIS_CHAT_DB)            # chat results (JSON history)
redis_api_client = _redis_for_db(REDIS_API_DB)              # API-key metadata hashes
redis_api_usage_client = _redis_for_db(REDIS_API_USAGE_DB)  # token-usage hashes
redis_task_client = _redis_for_db(REDIS_TASK_DB)            # task status / owner keys

# API keys are read from the "Authorization" request header. With
# auto_error=False a missing header yields None rather than an automatic error,
# so get_api_key can produce its own 401 response.
api_key_header = APIKeyHeader(auto_error=False, name="Authorization")

def get_audio_hash(text):
    """Return the hexadecimal MD5 digest of *text* encoded as UTF-8.

    MD5 serves only as a content fingerprint here; do not rely on it for
    anything security-sensitive.
    """
    digest = hashlib.md5()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()

# Validate the shape of the API key carried in the Authorization header.
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the bearer token from the Authorization header.

    The header must look like ``Bearer obs-...``. The token is the second
    space-separated field and must start with ``obs-``. Anything else (missing
    header, no ``Bearer `` prefix, wrong token prefix) is rejected.

    Returns:
        The raw key string.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header.
    """
    has_bearer = bool(api_key) and api_key.startswith("Bearer ")
    candidate = api_key.split(" ")[1] if has_bearer else None
    if candidate is not None and candidate.startswith("obs-"):
        return candidate
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )

def _decode_redis_hash(raw):
    """Decode a redis-py ``hgetall`` result (bytes -> bytes) into a str -> str dict."""
    return {k.decode('utf-8'): v.decode('utf-8') for k, v in raw.items()}


def _parse_expiry(value):
    """Parse a stored ``expires_at`` ISO-8601 string into an aware datetime.

    Returns None when the value is missing or unparseable. A trailing ``Z``
    is accepted, since ``datetime.fromisoformat`` only understands it from
    Python 3.11. A naive timestamp is interpreted as UTC, so it can be compared
    with ``datetime.now(timezone.utc)`` instead of raising ``TypeError``.
    """
    if not value:
        return None
    text = value.strip()
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate *api_key* against Redis and return its merged metadata.

    The key must exist as hash ``api_key:{api_key}`` in the API database, have
    ``is_active == '1'``, and a valid ``expires_at`` that is not in the past.

    Returns:
        A dict containing ``api_key``, every field of the API-key hash, and
        every field of the usage hash under the same key. Usage fields win on
        name clashes.

    Raises:
        HTTPException: 401 when the key is unknown, inactive, expired, or has
            a missing/unparseable expiry. Previously a bad expiry surfaced as
            an unhandled 500.
    """
    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = _decode_redis_hash(raw_info)

    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    expires_at = _parse_expiry(api_key_info.get('expires_at'))
    if expires_at is None:
        # Without a usable expiry we cannot prove the key is still valid.
        raise HTTPException(status_code=401, detail="无效的API密钥")
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage_info = _decode_redis_hash(redis_api_usage_client.hgetall(redis_key))

    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info,
    }

async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add *new_tokens_used* to a key's usage counters in Redis.

    Targets the hash ``api_key:{api_key}`` in the usage database. It increments
    the overall ``tokens_used`` field and the per-model
    ``{model_name}_tokens_used`` field. It also stamps ``last_used_at`` and
    ``{model_name}_last_used_at`` with the current UTC time. All four commands
    go out in one pipeline.
    """
    usage_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()

    # (counter field, timestamp field) pairs, overall first, then per model.
    field_pairs = (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    )

    pipe = redis_api_usage_client.pipeline()
    for counter_field, stamp_field in field_pairs:
        pipe.hincrby(usage_key, counter_field, new_tokens_used)
        pipe.hset(usage_key, stamp_field, now_iso)
    pipe.execute()

async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge *tokens_required* tokens, enqueue *task_data* on Kafka, and report usage.

    Args:
        api_key_info: Verified key info; only ``api_key_info['api_key']`` is read.
        model_name: Service name used for per-model usage fields and the
            response message (e.g. "tts").
        tokens_required: Tokens charged for this request.
        task_data: JSON-serializable payload sent to *kafka_topic*.
        kafka_topic: Destination topic.

    Returns:
        A dict with ``message``, ``tokens_used`` (this request),
        ``total_tokens_used``, ``{model_name}_tokens_used`` and
        ``tokens_remaining``.

    Raises:
        HTTPException: 403 when the balance cannot cover *tokens_required*.
        Any exception from ``producer.send`` is re-raised after the charge has
        been refunded.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"

    total_raw, used_raw = redis_api_usage_client.hmget(usage_key, ["total_tokens", "tokens_used"])
    total_tokens = int(total_raw or 0)
    tokens_used = int(used_raw or 0)

    # NOTE(review): this check and the charge below are not atomic, so
    # concurrent requests for the same key can overdraw the balance.
    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Charge first, then enqueue.
    await update_token_usage(api_key, tokens_required, model_name)

    try:
        producer.send(kafka_topic, task_data)
    except Exception:
        # The task never reached Kafka: give the tokens back before failing.
        await update_token_usage(api_key, -tokens_required, model_name)
        raise

    # Read back only the usage counters. Re-running verify_api_key here would
    # re-validate the key after the task is already charged and queued, which
    # could turn a successful enqueue into a 401. It also costs two extra
    # hgetall round trips.
    model_tokens_field = f"{model_name}_tokens_used"
    new_used_raw, model_used_raw = redis_api_usage_client.hmget(
        usage_key, ["tokens_used", model_tokens_field]
    )
    new_tokens_used = int(new_used_raw or 0)
    model_tokens_used = int(model_used_raw or 0)

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        model_tokens_field: model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used,
    }

# JSON request body for POST /v1_chat/tts.
class TTSRequest(BaseModel):
    # Text forwarded to the TTS worker as task_data["text"].
    text: str

# JSON request body for POST /v1_chat/chat. Each field is forwarded
# unchanged into the Kafka task payload.
class ChatRequest(BaseModel):
    # Forwarded as task_data["session_id"] (presumably a conversation id).
    session_id: str
    # Forwarded as task_data["query"].
    query: str
    # Forwarded as task_data["model"]; defaults to "qwen2.5:3b".
    model: str = "qwen2.5:3b"

# WebSocket connection registry, keyed by client id.
class ConnectionManager:
    """Track at most one open WebSocket per client id."""

    def __init__(self):
        # client_id -> WebSocket
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept *websocket* and register it under *client_id*.

        A newer connection for the same client id replaces the older one.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str, websocket: "WebSocket | None" = None):
        """Unregister *client_id*.

        When *websocket* is given, the entry is removed only if it still refers
        to that socket. This stops a stale connection that closes after the
        client reconnected from evicting the new one. When *websocket* is
        omitted, the entry is removed unconditionally, as before.
        """
        current = self.active_connections.get(client_id)
        if current is None:
            return
        if websocket is None or current is websocket:
            del self.active_connections[client_id]

    async def send_message(self, message: str, client_id: str):
        """Send *message* as text to *client_id*; silently do nothing if not connected."""
        websocket = self.active_connections.get(client_id)
        if websocket is not None:
            await websocket.send_text(message)


manager = ConnectionManager()


@v1_chat_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Keep a client's WebSocket registered until it closes.

    Incoming text frames are read and discarded. Presumably the connection
    exists so other code can push messages via ``manager.send_message``,
    though that call is not made in this section; verify.
    """
    await manager.connect(websocket, client_id)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, including on unexpected errors, and only this
        # socket so a newer reconnection survives.
        manager.disconnect(client_id, websocket)

async def _submit_task(api_key_info, model_name, task_id, task_data, kafka_topic, tokens_required=100):
    """Register *task_id* in Redis, then charge and enqueue it via process_request.

    The status key and the owner key are both written before the task is sent
    to Kafka, in case the worker looks either of them up. If charging or
    enqueueing fails (e.g. a 403 for an insufficient balance), both keys are
    deleted and the error is re-raised. This avoids leaving a phantom "queued"
    task that no client was ever told about.

    Returns:
        The process_request result dict with ``task_id`` added.
    """
    status_key = f"task_status:{task_id}"
    client_key = f"task_client:{task_id}"

    redis_task_client.set(status_key, "queued")
    # Record which API key owns the task, for later WebSocket communication.
    redis_task_client.set(client_key, api_key_info['api_key'])

    try:
        result = await process_request(api_key_info, model_name, tokens_required, task_data, kafka_topic)
    except Exception:
        redis_task_client.delete(status_key, client_key)
        raise

    result["task_id"] = task_id
    return result


@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a text-to-speech task (100 tokens) and return its task id and usage."""
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "text": request.text,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    return await _submit_task(api_key_info, "tts", task_id, task_data, KAFKA_TTS_TOPIC)


@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Queue a speech-recognition task for an uploaded audio file (100 tokens).

    The upload is saved as ``{task_id}.wav`` under ``/obscura/task/audio_upload``
    and only that path is sent to the worker. If the request is rejected for
    an insufficient balance (403), the saved file is removed again.
    """
    task_id = str(uuid.uuid4())

    upload_dir = "/obscura/task/audio_upload"
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, f"{task_id}.wav")

    # Read before opening so a failed upload read does not leave an empty file.
    content = await audio.read()
    with open(file_path, "wb") as audio_file:
        audio_file.write(content)

    task_data = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }

    try:
        return await _submit_task(api_key_info, "asr", task_id, task_data, KAFKA_ASR_TOPIC)
    except HTTPException as err:
        # process_request raises 403 before anything is enqueued, so no worker
        # will ever read this upload; clean it up (best effort).
        if err.status_code == 403:
            with contextlib.suppress(OSError):
                os.remove(file_path)
        raise


@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a chat task (100 tokens) and return its task id and usage."""
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "session_id": request.session_id,
        "query": request.query,
        "model": request.model,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    return await _submit_task(api_key_info, "chat", task_id, task_data, KAFKA_CHAT_TOPIC)

def _load_task_status(task_id, api_key):
    """Return the decoded status of *task_id*, or None if it should be reported as not found.

    A task whose ``task_client:{task_id}`` owner record names a different API
    key is treated as not found. Reporting it as forbidden would let callers
    probe other users' task ids. Tasks without an owner record stay visible,
    as before.
    """
    raw_status = redis_task_client.get(f"task_status:{task_id}")
    if not raw_status:
        return None
    owner = redis_task_client.get(f"task_client:{task_id}")
    if owner is not None and owner.decode('utf-8') != api_key:
        return None
    return raw_status.decode('utf-8')


@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Return a chat task's status and, once completed, its stored history."""
    status = _load_task_status(task_id, api_key_info['api_key'])
    if status is None:
        return {"status": "not_found"}

    if status == "completed":
        chat_result = redis_chat_client.get(task_id)
        if chat_result:
            # The stored JSON history is returned as-is.
            return {
                "status": "completed",
                "history": json.loads(chat_result),
            }
    return {"status": status}


@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Return a TTS task's status and, once completed, the stored audio path."""
    status = _load_task_status(task_id, api_key_info['api_key'])
    if status is None:
        return {"status": "not_found"}

    if status == "completed":
        audio_info = redis_tts_client.get(task_id)
        if audio_info:
            # .get() instead of ['path']: a result without a path falls back
            # to the plain status instead of failing with a KeyError (500).
            audio_path = json.loads(audio_info).get('path')
            if audio_path:
                return {
                    "status": "completed",
                    "audio_path": audio_path,
                }
    return {"status": status}


@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Return an ASR task's status and, once completed, its transcription."""
    status = _load_task_status(task_id, api_key_info['api_key'])
    if status is None:
        return {"status": "not_found"}

    if status == "completed":
        transcription = redis_asr_client.get(task_id)
        if transcription:
            return {
                "status": "completed",
                "transcription": transcription.decode('utf-8'),
            }
    return {"status": status}

if __name__ == "__main__":
    # Serve the root app (including the mounted /v1_chat API) directly with
    # uvicorn on all interfaces, port 8008.
    import uvicorn

    uvicorn.run(app, port=8008, host="0.0.0.0")