# Python source · 115 lines · 3.3 KiB
import whisper
|
|
import os
|
|
import json
|
|
import redis
|
|
from dotenv import load_dotenv
|
|
from kafka import KafkaConsumer
|
|
import threading
|
|
|
|
# Index of the GPU this worker should use; edit this value to pick another one.
GPU_ID = 1

# Export the chosen index through the CUDA_VISIBLE_DEVICES environment variable.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{GPU_ID}"

# Load environment variables (python-dotenv).
load_dotenv()

# Load the Whisper "large-v3" model once at import time; the startup banners
# bracket the load.
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
|
|
|
|
def _required_int_env(name: str) -> int:
    """Return environment variable *name* parsed as an integer.

    A bare ``int(os.getenv(name))`` fails with an opaque ``TypeError`` when
    the variable is unset; this helper names the offending variable instead.

    Raises:
        RuntimeError: if the variable is not set.
        ValueError: if the value is not a valid integer.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None


# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')

# Redis configuration
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _required_int_env('REDIS_PORT')
REDIS_ASR_DB = _required_int_env('REDIS_ASR_DB')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = _required_int_env('REDIS_TASK_DB')
|
|
|
|
# Connection parameters common to both Redis clients; only the database differs.
_redis_connection = {
    "host": REDIS_HOST,
    "port": REDIS_PORT,
    "password": REDIS_PASSWORD,
}

# Client bound to the REDIS_ASR_DB database.
redis_asr_client = redis.Redis(db=REDIS_ASR_DB, **_redis_connection)

# Client bound to the REDIS_TASK_DB database.
redis_task_client = redis.Redis(db=REDIS_TASK_DB, **_redis_connection)
|
|
|
|
def process_audio(file_path: str, client_id: str, cache_key: str) -> None:
    """Transcribe one audio file and record the outcome in Redis.

    The task status stored under ``task_status:<cache_key>`` moves from
    "processing" to "completed", or to "error" if any step fails. On success
    the transcription is cached under ``cache_key`` for one hour and published
    as JSON on the ``asr_results`` channel, after which the input file is
    deleted. On failure the input file is left in place.

    Args:
        file_path: Path of the audio file handed to ``model.transcribe``.
        client_id: Identifier included in the published result payload.
        cache_key: Redis key for the cached transcription; also the suffix of
            the task-status key.
    """
    status_key = f"task_status:{cache_key}"

    try:
        # Mark the task as "processing".
        redis_task_client.set(status_key, "processing")

        result = model.transcribe(file_path)
        transcription = result['text']

        print(f"处理了文件: {file_path}")
        print(f"转录结果: {transcription}")

        # Cache the transcription for one hour.
        redis_asr_client.setex(cache_key, 3600, transcription)

        # Publish the result on the Redis channel.
        result_data = {
            'client_id': client_id,
            'transcription': transcription,
        }
        redis_asr_client.publish('asr_results', json.dumps(result_data))

        # Mark the task as "completed".
        redis_task_client.set(status_key, "completed")
    except Exception as e:
        print(f"处理音频文件时发生错误: {str(e)}")
        # Mark the task as "error".
        redis_task_client.set(status_key, "error")
        return

    # Remove the temporary file. This runs outside the main try block so that
    # a cleanup failure cannot flip an already delivered, "completed" task
    # back to "error".
    try:
        os.remove(file_path)
    except OSError as e:
        print(f"清理临时文件失败: {file_path}: {e}")
|
|
|
|
def _decode_task(raw):
    """Decode a raw Kafka message value into a Python object.

    Returns ``None`` instead of raising when the value is missing or is not
    valid UTF-8 encoded JSON. The deserializer is invoked by the consumer
    while iterating, i.e. outside the per-message ``try`` in
    ``kafka_consumer``, so an exception here would end the consume loop.
    """
    if raw is None:
        return None
    try:
        return json.loads(raw.decode('utf-8'))
    except ValueError as e:
        # UnicodeDecodeError and json.JSONDecodeError both derive from ValueError.
        print(f"无法解析消息: {e}")
        return None


def kafka_consumer():
    """Consume ASR tasks from Kafka and process them one at a time.

    Each message value is expected to be a JSON object with ``file_path``,
    ``task_id`` and ``status`` fields. Messages that are not such an object,
    lack a field, or whose status is not ``'queued'`` are reported and
    skipped. Errors raised while handling a single message are logged and do
    not stop the loop.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=_decode_task,
        group_id='asr_group',
        auto_offset_reset='earliest',
        enable_auto_commit=True,
    )

    print("ASR消费者已启动")

    for message in consumer:
        try:
            task = message.value

            # Undecodable payloads arrive as None; any non-dict JSON value
            # (list, string, number) cannot carry the required fields either.
            if not isinstance(task, dict):
                print(f"收到无效任务: {task}")
                continue

            file_path = task.get('file_path')
            task_id = task.get('task_id')
            status = task.get('status')

            if not file_path or not task_id or status != 'queued':
                print(f"收到无效任务: {task}")
                continue

            cache_key = f"asr:{task_id}"
            client_id = task_id  # the task_id doubles as the client_id

            print(f"开始处理任务: {cache_key}")
            process_audio(file_path, client_id, cache_key)
            print(f"完成处理任务: {cache_key}")

        except Exception as e:
            print(f"处理消息时发生错误: {str(e)}")
|
|
|
|
# Script entry point: print the startup banner, then run the consumer loop.
if __name__ == "__main__":
    print("启动Kafka消费者处理ASR请求...")
    kafka_consumer()