Files
api/api_chat/before/asr.py
T
2025-01-12 06:15:15 +00:00

115 lines
3.3 KiB
Python

import whisper
import os
import json
import redis
from dotenv import load_dotenv
from kafka import KafkaConsumer
import threading
# Index of the GPU this worker should use; edit this value to pick another.
GPU_ID = 1
# Export the choice through CUDA_VISIBLE_DEVICES before the model is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)

# Load environment variables (e.g. from a .env file).
load_dotenv()

print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")

# Kafka settings.
KAFKA_BROKER = os.environ.get('KAFKA_BROKER')
KAFKA_TOPIC = os.environ.get('KAFKA_ASR_TOPIC')

# Redis settings. The int() conversions raise if a variable is unset.
REDIS_HOST = os.environ.get('REDIS_HOST')
REDIS_PORT = int(os.environ.get('REDIS_PORT'))
REDIS_ASR_DB = int(os.environ.get('REDIS_ASR_DB'))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.environ.get('REDIS_TASK_DB'))

# Both Redis clients target the same server and differ only in the database.
_REDIS_CONNECTION = {
    'host': REDIS_HOST,
    'port': REDIS_PORT,
    'password': REDIS_PASSWORD,
}
# Client for cached transcriptions and the result channel.
redis_asr_client = redis.Redis(db=REDIS_ASR_DB, **_REDIS_CONNECTION)
# Client for task-status bookkeeping.
redis_task_client = redis.Redis(db=REDIS_TASK_DB, **_REDIS_CONNECTION)
def process_audio(file_path: str, client_id: str, cache_key: str) -> None:
    """Transcribe one audio file and publish the result through Redis.

    The status stored under ``task_status:<cache_key>`` is set to
    ``"processing"`` first, then to ``"completed"`` once the transcription
    has been cached and published, or to ``"error"`` if any of those steps
    raises. The audio file is deleted only after a successful run; it is
    left on disk when processing fails.

    Args:
        file_path: Path of the audio file to transcribe.
        client_id: Identifier included in the message published on the
            ``asr_results`` channel.
        cache_key: Redis key under which the transcription is cached for
            one hour; also used to build the task-status key.
    """
    status_key = f"task_status:{cache_key}"
    try:
        # Mark the task as in progress.
        redis_task_client.set(status_key, "processing")

        result = model.transcribe(file_path)
        transcription = result['text']
        print(f"处理了文件: {file_path}")
        print(f"转录结果: {transcription}")

        # Cache the transcription for one hour.
        redis_asr_client.setex(cache_key, 3600, transcription)

        # Publish the result on the Redis channel.
        result_data = {
            'client_id': client_id,
            'transcription': transcription,
        }
        redis_asr_client.publish('asr_results', json.dumps(result_data))

        # Mark the task as finished.
        redis_task_client.set(status_key, "completed")
    except Exception as e:
        print(f"处理音频文件时发生错误: {str(e)}")
        # Mark the task as failed.
        redis_task_client.set(status_key, "error")
        return

    # Cleanup runs outside the try block above on purpose: the result has
    # already been stored and published, so a failure to delete the file
    # must not flip the task status from "completed" to "error".
    try:
        os.remove(file_path)
    except OSError as e:
        print(f"清理临时文件失败: {file_path}: {e}")
def _decode_task(raw):
    """Decode a raw Kafka message value into a Python object.

    Returns ``None`` instead of raising when the value is missing, is not
    valid UTF-8, or is not valid JSON, so that a single malformed message
    cannot abort the consumer loop.
    """
    if raw is None:
        return None
    try:
        return json.loads(raw.decode('utf-8'))
    except (UnicodeDecodeError, json.JSONDecodeError):
        return None


def kafka_consumer():
    """Consume ASR tasks from Kafka and transcribe them one at a time.

    Each message is expected to be a JSON object with a non-empty
    ``file_path``, a non-empty ``task_id`` and ``status == 'queued'``;
    anything else is logged as an invalid task and skipped. Valid tasks are
    handed to ``process_audio`` with the cache key ``asr:<task_id>``.

    The loop runs until the consumer iterator stops or an exception escapes
    it; the consumer is closed on the way out.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        # kafka-python runs the deserializer while fetching the next record,
        # i.e. outside the per-message try block below, so it must not raise.
        value_deserializer=_decode_task,
        group_id='asr_group',
        auto_offset_reset='earliest',
        enable_auto_commit=True,
    )
    print("ASR消费者已启动")
    try:
        for message in consumer:
            try:
                task = message.value
                # Undecodable payloads arrive as None; other non-object JSON
                # values (lists, strings, numbers) are rejected here as well.
                if not isinstance(task, dict):
                    print(f"收到无效任务: {task}")
                    continue

                file_path = task.get('file_path')
                task_id = task.get('task_id')
                status = task.get('status')
                if not file_path or not task_id or status != 'queued':
                    print(f"收到无效任务: {task}")
                    continue

                cache_key = f"asr:{task_id}"
                client_id = task_id  # task_id doubles as the client_id
                print(f"开始处理任务: {cache_key}")
                process_audio(file_path, client_id, cache_key)
                print(f"完成处理任务: {cache_key}")
            except Exception as e:
                # Keep consuming even if one task fails unexpectedly.
                print(f"处理消息时发生错误: {str(e)}")
    finally:
        # Release the consumer's connections when the loop ends.
        consumer.close()
def main():
    """Script entry point: announce start-up, then run the consumer loop."""
    print("启动Kafka消费者处理ASR请求...")
    kafka_consumer()


if __name__ == "__main__":
    main()