186 lines
5.7 KiB
Python
186 lines
5.7 KiB
Python
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import whisper
|
|
import tempfile
|
|
import uvicorn
|
|
from kafka import KafkaProducer, KafkaConsumer
|
|
from starlette.websockets import WebSocketDisconnect
|
|
import json
|
|
import threading
|
|
import os
|
|
import uuid
|
|
import asyncio
|
|
import logging
|
|
import redis
|
|
from dotenv import load_dotenv
|
|
|
|
# Load .env FIRST so every os.getenv below (including GPU selection and the
# Kafka/Redis config later in this module) sees values from the .env file.
# (Previously load_dotenv() ran after app creation and after GPU setup.)
load_dotenv()

# ID of the GPU to use; now overridable via the GPU_ID environment variable.
# Defaults to 1, the previously hard-coded value, so behavior is unchanged
# when the variable is absent.
GPU_ID = int(os.getenv("GPU_ID", "1"))

# Restrict CUDA to the selected device.  Must be set before any CUDA-using
# library (torch / whisper) initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
|
|
|
|
|
|
# CORS configuration.
# BUG FIX: the previous code called .split(',') directly on
# os.getenv('ALLOWED_ORIGINS'), which raises AttributeError when the variable
# is unset.  Default to an empty string, strip whitespace around each origin,
# and drop blank entries.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv('ALLOWED_ORIGINS', '').split(',')
    if origin.strip()
]

# Register the CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Load the Whisper ASR model once at import time (slow; may download weights
# on first run).  Use the module logger rather than print() for consistency
# with the logging configured above.
logger.info("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
logger.info("Whisper模型加载完成。")
|
|
|
|
# Kafka configuration (no sensible default for a broker address; a missing
# value will surface when the producer/consumer is created).
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')

# Redis configuration.
# BUG FIX: int(os.getenv(...)) raised TypeError when the variable was unset;
# fall back to the conventional host/port/db defaults instead.
REDIS_HOST = os.getenv('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.getenv('REDIS_PORT', '6379'))
REDIS_DB = int(os.getenv('REDIS_ASR_DB', '0'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')  # may legitimately be None
|
|
|
|
# Redis client used to cache transcription results (written by the consumer
# threads, read by the /asr endpoint).  redis-py connects lazily, so a bad
# host/port only surfaces on the first command.
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD  # password from REDIS_PASSWORD; None disables auth
)

# Kafka producer that queues transcription tasks; task dicts are serialized
# to UTF-8 JSON.
# NOTE(review): KafkaProducer connects at construction time, so an
# unreachable broker makes this module fail on import — confirm intended.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

# client_id -> WebSocket for currently connected clients; written by
# websocket_endpoint, read by send_transcription.
active_connections: dict = {}
|
|
|
|
@app.websocket("/asr/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Hold one WebSocket open per client so transcription results can be pushed.

    Protocol: replies "pong" to "ping", echoes any other text message, and
    sends "heartbeat" after 30 seconds of client silence so dead connections
    are detected.  The socket is registered in active_connections for the
    lifetime of the connection.
    """
    await websocket.accept()
    active_connections[client_id] = websocket
    try:
        while True:
            try:
                # Wait up to 30s for a client message before heart-beating.
                data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
                if data == "ping":
                    await websocket.send_text("pong")
                else:
                    await websocket.send_text(f"收到消息: {data}")
            except asyncio.TimeoutError:
                try:
                    # Probe the connection.  BUG FIX: a send on an
                    # already-closed connection may raise RuntimeError
                    # (ASGI close message already sent) rather than
                    # WebSocketDisconnect, so catch both here.
                    await websocket.send_text("heartbeat")
                except (WebSocketDisconnect, RuntimeError):
                    logger.info(f"客户端 {client_id} 断开连接")
                    break
    except WebSocketDisconnect:
        logger.info(f"客户端 {client_id} 断开连接")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Always drop the registration so send_transcription never targets
        # a stale socket.
        active_connections.pop(client_id, None)
|
|
|
|
|
|
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
    """Accept an uploaded audio file and queue it for ASR transcription.

    Returns immediately with a client_id; the transcription result is either
    served straight from the Redis cache (when the same audio content was
    transcribed before) or delivered later over /asr/ws/{client_id}.
    """
    import hashlib

    if not audio:
        raise HTTPException(status_code=400, detail="未提供音频文件")

    client_id = str(uuid.uuid4())

    # Read the upload once up front so the cache key can be derived from it.
    content = await audio.read()

    # BUG FIX: the previous cache key embedded the per-request client_id
    # (a fresh uuid4), so redis_client.get() below could NEVER hit.
    # Key on the content hash instead: identical uploads now reuse the
    # cached transcription regardless of filename or requester.
    cache_key = f"asr:{hashlib.sha256(content).hexdigest()}"

    # NOTE(review): redis-py calls are blocking inside an async handler;
    # consider redis.asyncio or run_in_executor under load.
    cached_result = redis_client.get(cache_key)
    if cached_result:
        logger.info(f"缓存命中: {cache_key}")
        return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}

    # Persist the audio to a temp file for the consumer thread to read.
    # delete=False: the file must outlive this request; the consumer removes
    # it after transcription.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
        temp_audio.write(content)
        temp_audio.flush()

    task = {
        'file_path': temp_audio.name,
        'client_id': client_id,
        'cache_key': cache_key
    }
    producer.send(KAFKA_TOPIC, value=task)
    producer.flush()

    logger.info(f"发送任务到Kafka: {task}")
    return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
|
|
|
|
async def send_transcription(client_id: str, transcription: str):
    """Push a transcription result to the client's registered WebSocket.

    Logs a warning and does nothing when the client has no live connection
    in active_connections.
    """
    websocket = active_connections.get(client_id)
    if websocket is None:
        logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
        return
    await websocket.send_json({"transcription": transcription})
|
|
|
|
def kafka_consumer(consumer_id):
    """Worker-thread loop: consume ASR tasks from Kafka and transcribe them.

    Each task dict (produced by the /asr endpoint) carries file_path,
    client_id and cache_key.  The result is cached in Redis for an hour,
    pushed to the client's WebSocket, then the temp file is deleted and the
    offset committed.

    Args:
        consumer_id: integer label used only in log messages.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda x: json.loads(x.decode('utf-8')),
        group_id='asr_group',
        max_poll_interval_ms=300000
    )

    for message in consumer:
        try:
            task = message.value
            file_path = task.get('file_path')
            client_id = task.get('client_id')
            cache_key = task.get('cache_key')

            # Malformed tasks are logged, committed and skipped so they are
            # not redelivered forever.
            if not file_path or not client_id or not cache_key:
                logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
                consumer.commit()
                continue

            result = model.transcribe(file_path)

            logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
            logger.info(f"转录结果: {result['text']}")

            # Store the result in the Redis cache
            redis_client.setex(cache_key, 3600, result['text'])  # cache for 1 hour

            # NOTE(review): asyncio.run() creates a NEW event loop in this
            # worker thread, while the WebSocket in send_transcription belongs
            # to the server's loop — sending across loops is unsafe; consider
            # asyncio.run_coroutine_threadsafe with the server loop instead.
            asyncio.run(send_transcription(client_id, result['text']))

            os.remove(file_path)
            consumer.commit()
        except Exception as e:
            # NOTE(review): on failure the temp file stays on disk and the
            # offset is NOT committed, so the message will be redelivered —
            # a task that always fails loops indefinitely.  Confirm intended.
            logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
|
|
|
|
def start_consumers(num_consumers=1):
    """Launch num_consumers background Kafka consumer threads.

    BUG FIX: the threads are now daemons — kafka_consumer loops forever, so
    non-daemon threads kept the interpreter alive and prevented a clean exit
    after the uvicorn server shut down.

    Args:
        num_consumers: how many parallel consumer threads to start.

    Returns:
        The list of started Thread objects (previously None; existing
        callers that ignore the return value are unaffected).
    """
    threads = []
    for i in range(num_consumers):
        consumer_thread = threading.Thread(
            target=kafka_consumer,
            args=(i,),
            name=f"asr-consumer-{i}",
            daemon=True,  # do not block process exit
        )
        consumer_thread.start()
        threads.append(consumer_thread)
    return threads
|
|
|
|
if __name__ == '__main__':
    # Start the background Kafka consumer thread(s), then run the API server.
    # NOTE(review): port 6000 is on the browsers' blocked-port list (X11);
    # confirm web clients can actually reach this service on it.
    start_consumers()
    uvicorn.run(app, host="0.0.0.0", port=6000)