Files
api/api_chat/before/whisper_api.py
T
2025-01-12 06:15:15 +00:00

186 lines
5.7 KiB
Python

from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
from fastapi.middleware.cors import CORSMiddleware
import whisper
import tempfile
import uvicorn
from kafka import KafkaProducer, KafkaConsumer
from starlette.websockets import WebSocketDisconnect
import json
import threading
import os
import uuid
import asyncio
import logging
import redis
from dotenv import load_dotenv
# Load variables from a local .env file before any configuration is read
# (python-dotenv does not override variables already set in the process).
load_dotenv()

# GPU to run Whisper on. CUDA_VISIBLE_DEVICES only takes effect if it is set
# before CUDA is initialised, so it is set here, ahead of the model load below.
GPU_ID = 1  # change this value to select a different GPU
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS: ALLOWED_ORIGINS is a comma-separated list of origins. Whitespace around
# entries and empty entries are ignored ("a, b," -> ["a", "b"]); an unset
# variable yields an empty list (no cross-origin access) instead of crashing
# at import time with an AttributeError on None.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv('ALLOWED_ORIGINS', '').split(',')
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the Whisper model once at import time; it is shared by every Kafka
# consumer thread. Progress goes through the module logger configured above
# rather than bare print().
logger.info("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
logger.info("Whisper模型加载完成。")
# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')

# Redis configuration. Host/port/DB fall back to the Redis defaults
# (localhost / 6379 / 0) so a missing variable does not fail at import with
# an opaque `int(None)` TypeError.
REDIS_HOST = os.getenv('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.getenv('REDIS_PORT', '6379'))
REDIS_DB = int(os.getenv('REDIS_ASR_DB', '0'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')  # optional

# Redis client used as the transcription cache.
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD,
)

# Kafka producer that publishes ASR tasks as UTF-8 JSON.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# Registry of open WebSocket connections, keyed by client_id. Entries are
# added and removed by the /asr/ws/{client_id} endpoint and looked up when a
# transcription result is pushed to a client.
active_connections = dict()
@app.websocket("/asr/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Hold a WebSocket open for *client_id* so transcriptions can be pushed to it.

    The socket is registered in ``active_connections`` while it is open.
    Incoming ``"ping"`` is answered with ``"pong"``; any other text is echoed
    back prefixed with ``"收到消息: "``. If the client sends nothing for 30
    seconds, a ``"heartbeat"`` message is sent and the wait starts again.
    """
    await websocket.accept()
    active_connections[client_id] = websocket
    try:
        while True:
            try:
                # Wait at most 30 s for a message, then send a heartbeat.
                data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
            except asyncio.TimeoutError:
                await websocket.send_text("heartbeat")
                continue
            if data == "ping":
                await websocket.send_text("pong")
            else:
                await websocket.send_text(f"收到消息: {data}")
    except WebSocketDisconnect:
        logger.info("客户端 %s 断开连接", client_id)
    except Exception as e:
        logger.exception("WebSocket错误: %s", e)
    finally:
        # Only unregister if the entry still refers to *this* socket. If the
        # client reconnected with the same id, the newer socket has replaced
        # it, and deleting unconditionally would silently drop that connection.
        if active_connections.get(client_id) is websocket:
            del active_connections[client_id]
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
    """Accept an audio upload and either return a cached transcription or queue it.

    The cache key is the SHA-256 of the uploaded bytes, so uploading identical
    audio again returns the cached text immediately. Otherwise the audio is
    written to a temporary ``.wav`` file and a task with ``file_path``,
    ``client_id`` and ``cache_key`` is published to Kafka. The result is later
    pushed to the WebSocket registered under the returned ``client_id``.

    Raises:
        HTTPException: 400 when no file, or an empty file, is uploaded.
    """
    import hashlib
    from starlette.concurrency import run_in_threadpool

    if not audio:
        raise HTTPException(status_code=400, detail="未提供音频文件")
    content = await audio.read()
    if not content:
        raise HTTPException(status_code=400, detail="音频文件为空")

    client_id = str(uuid.uuid4())
    # Key the cache on the audio bytes. A key containing the per-request UUID
    # is never seen twice, so such a cache could never hit.
    cache_key = f"asr:{hashlib.sha256(content).hexdigest()}"

    # The Redis and Kafka clients are synchronous. Run them in the threadpool
    # so they do not block the event loop (and every open WebSocket).
    cached_result = await run_in_threadpool(redis_client.get, cache_key)
    if cached_result:
        logger.info("缓存命中: %s", cache_key)
        return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
        temp_audio.write(content)
    task = {
        'file_path': temp_audio.name,
        'client_id': client_id,
        'cache_key': cache_key
    }

    def _publish():
        # Send the task and block until the producer has flushed it.
        producer.send(KAFKA_TOPIC, value=task)
        producer.flush()

    try:
        await run_in_threadpool(_publish)
    except Exception:
        # No consumer will ever process this file, so do not leave it on disk.
        try:
            os.remove(temp_audio.name)
        except OSError:
            logger.warning("删除临时文件 %s 失败", temp_audio.name)
        raise
    logger.info("发送任务到Kafka: %s", task)
    return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
async def send_transcription(client_id: str, transcription: str):
    """Send ``{"transcription": transcription}`` as JSON over the client's WebSocket.

    If no connection is registered for *client_id*, a warning is logged and
    nothing is sent.
    """
    try:
        target_socket = active_connections[client_id]
    except KeyError:
        logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
        return
    payload = {"transcription": transcription}
    await target_socket.send_json(payload)
# Event loop that runs the ASGI app and therefore owns every WebSocket in
# `active_connections`. It is set by the startup hook registered under
# __main__, and consumer threads schedule WebSocket sends onto it.
_server_loop = None

# Seconds a consumer thread waits for a WebSocket push to finish.
_NOTIFY_TIMEOUT = 10

# Keys every ASR task must carry (all non-empty).
_TASK_KEYS = ('file_path', 'client_id', 'cache_key')


def _decode_task(raw_value):
    """Kafka value deserializer: parse UTF-8 JSON, returning None if malformed.

    Exceptions raised inside a deserializer escape from the consumer's iterator
    and would end the consumer thread, so bad payloads become None and are
    rejected as invalid tasks instead.
    """
    try:
        return json.loads(raw_value.decode('utf-8'))
    except (AttributeError, ValueError):  # None value / bad UTF-8 / bad JSON
        return None


def _is_valid_task(task):
    """Return True if *task* is a dict with non-empty values for every required key."""
    return isinstance(task, dict) and all(task.get(key) for key in _TASK_KEYS)


def _remove_temp_file(file_path):
    """Delete a task's temporary audio file, logging (not raising) on failure."""
    try:
        os.remove(file_path)
    except OSError as e:
        logger.warning("删除临时文件 %s 失败: %s", file_path, e)


def _notify_client(client_id, transcription):
    """Push *transcription* to *client_id*'s WebSocket from a worker thread.

    asyncio objects are not thread-safe and are bound to the loop that created
    them, so the send is scheduled onto ``_server_loop`` with
    ``run_coroutine_threadsafe`` instead of being run in a fresh event loop.
    If the server loop is not available, delivery is skipped with a warning.
    The result is still cached in Redis.
    """
    loop = _server_loop
    if loop is None or loop.is_closed():
        logger.warning("服务器事件循环不可用，无法推送客户端 %s 的转录结果", client_id)
        return
    future = asyncio.run_coroutine_threadsafe(
        send_transcription(client_id, transcription), loop
    )
    try:
        future.result(timeout=_NOTIFY_TIMEOUT)
    except Exception:
        future.cancel()  # no-op if the send already finished
        raise


def _process_task(consumer_id, task):
    """Transcribe one task's audio, cache and deliver the text, then delete the file.

    Errors are logged, not raised. The temporary file is removed whether or
    not processing succeeded, because failed tasks are not retried.
    """
    file_path = task['file_path']
    try:
        # NOTE(review): `model` is shared by all consumer threads. Confirm
        # model.transcribe is safe to call concurrently before running more
        # than one consumer.
        result = model.transcribe(file_path)
        text = result['text']
        logger.info("消费者 %s 处理了文件: %s", consumer_id, file_path)
        logger.info("转录结果: %s", text)
        redis_client.setex(task['cache_key'], 3600, text)  # cache for 1 hour
        _notify_client(task['client_id'], text)
    except Exception as e:
        logger.exception("消费者 %s 处理消息时发生错误: %s", consumer_id, e)
    finally:
        _remove_temp_file(file_path)


def kafka_consumer(consumer_id):
    """Consume ASR tasks from Kafka forever (intended to run in its own thread).

    Invalid tasks are logged and skipped. Valid tasks are transcribed, cached
    in Redis for one hour, and pushed to the client's WebSocket. The offset is
    committed after every message, whether or not processing succeeded.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=_decode_task,
        group_id='asr_group',
        # NOTE(review): a single large-v3 transcription that takes longer than
        # this (5 min) will trigger a rebalance and make the commit fail.
        max_poll_interval_ms=300000
    )
    for message in consumer:
        task = message.value
        if _is_valid_task(task):
            _process_task(consumer_id, task)
        else:
            logger.error("消费者 %s 收到无效任务: %s", consumer_id, task)
        try:
            consumer.commit()
        except Exception as e:
            logger.exception("消费者 %s 提交偏移量失败: %s", consumer_id, e)


def start_consumers(num_consumers=1):
    """Start *num_consumers* daemon threads, each running ``kafka_consumer``.

    The threads are daemons so the process can exit when the web server stops,
    instead of hanging on consumers that loop forever.

    Returns:
        list[threading.Thread]: the started threads.
    """
    threads = []
    for i in range(num_consumers):
        consumer_thread = threading.Thread(
            target=kafka_consumer, args=(i,), name=f"asr-consumer-{i}", daemon=True
        )
        consumer_thread.start()
        threads.append(consumer_thread)
    return threads


async def _start_consumers_on_startup():
    """Server startup hook: remember the running event loop, then start the consumers."""
    global _server_loop
    _server_loop = asyncio.get_running_loop()
    start_consumers()


if __name__ == '__main__':
    # Start the consumers from the startup hook rather than before uvicorn.run,
    # so the event loop they deliver results on exists before any task is handled.
    app.add_event_handler("startup", _start_consumers_on_startup)
    uvicorn.run(app, host="0.0.0.0", port=6000)