Initial commit
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
# Kafka 配置
|
||||
KAFKA_BROKER=222.186.10.253:9092
|
||||
KAFKA_ASR_TOPIC=asr
|
||||
KAFKA_CHAT_TOPIC=chat
|
||||
KAFKA_TTS_TOPIC=tts
|
||||
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST=150.158.144.159
|
||||
REDIS_PORT=13003
|
||||
REDIS_ASR_DB=12
|
||||
REDIS_CHAT_DB=13
|
||||
REDIS_TTS_DB=14
|
||||
REDIS_PASSWORD=Obscura@2024
|
||||
REDIS_API_DB=2
|
||||
REDIS_API_USAGE_DB=3
|
||||
REDIS_TASK_DB=11
|
||||
REDIS_SESSION_DB=63
|
||||
|
||||
REDIS_SESSION_DB_ZH=63
|
||||
REDIS_SESSION_DB_EN=62
|
||||
REDIS_SESSION_DB_KO=61
|
||||
|
||||
# CORS 配置
|
||||
# ALLOWED_ORIGINS=https://beta.obscura.work
|
||||
|
||||
|
||||
# GPT-SoVITS 配置
|
||||
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||
REF_AUDIO_PATH=sample/woman.wav
|
||||
REF_TEXT_PATH=sample/woman.txt
|
||||
REF_LANGUAGE=中文
|
||||
TARGET_LANGUAGE=多语种混合
|
||||
OUTPUT_PATH=/obscura/task/audio_files
|
||||
|
||||
# VOICE_CONFIGS
|
||||
GIRL_REF_AUDIO=sample/gril.wav
|
||||
GIRL_REF_TEXT=sample/gril.txt
|
||||
|
||||
WOMAN_REF_AUDIO=sample/woman.wav
|
||||
WOMAN_REF_TEXT=sample/woman.txt
|
||||
|
||||
|
||||
MAN_REF_AUDIO=sample/man.wav
|
||||
MAN_REF_TEXT=sample/man.txt
|
||||
|
||||
LEIJUN_REF_AUDIO=sample/leijun.wav
|
||||
LEIJUN_REF_TEXT=sample/leijun.txt
|
||||
|
||||
DUFU_REF_AUDIO=sample/dufu.wav
|
||||
DUFU_REF_TEXT=sample/dufu.txt
|
||||
|
||||
HEJIONG_REF_AUDIO=sample/hejiong.wav
|
||||
HEJIONG_REF_TEXT=sample/hejiong.txt
|
||||
|
||||
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
|
||||
MAHUATENG_REF_TEXT=sample/mahuateng.txt
|
||||
|
||||
LIDAN_REF_AUDIO=sample/lidan.wav
|
||||
LIDAN_REF_TEXT=sample/lidan.txt
|
||||
|
||||
YUHUA_REF_AUDIO=sample/yuhua.wav
|
||||
YUHUA_REF_TEXT=sample/yuhua.txt
|
||||
|
||||
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
|
||||
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
|
||||
|
||||
DABING_REF_AUDIO=sample/dabing.wav
|
||||
DABING_REF_TEXT=sample/dabing.txt
|
||||
|
||||
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
|
||||
LUOXIANG_REF_TEXT=sample/luoxiang.txt
|
||||
|
||||
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
|
||||
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
|
||||
|
||||
|
||||
REDIS_GIRL_DB=15
REDIS_WOMAN_DB=16
REDIS_MAN_DB=17
REDIS_LEIJUN_DB=18
REDIS_DUFU_DB=19
REDIS_HEJIONG_DB=20
REDIS_MAHUATENG_DB=21
REDIS_LIDAN_DB=22
REDIS_DABING_DB=23
REDIS_LUOXIANG_DB=24
REDIS_XUZHIYUAN_DB=25
REDIS_YUHUA_DB=26
REDIS_LIUZHENYUN_DB=27
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+110
@@ -0,0 +1,110 @@
|
||||
import whisper
|
||||
import os
|
||||
import json
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import asyncio
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_asr_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_ASR_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
async def process_audio(file_path: str, cache_key: str):
    """Transcribe one audio file with Whisper and publish the result.

    Tracks the task status in Redis ("processing" -> "completed"/"error"),
    caches the transcription for one hour under *cache_key*, and publishes
    it on the 'asr_results' channel.  The input file is always removed
    afterwards — previously it leaked when transcription raised.
    """
    try:
        # Mark the task as in progress.
        redis_task_client.set(f"task_status:{cache_key}", "processing")

        result = model.transcribe(file_path)
        transcription = result['text']

        print(f"处理了文件: {file_path}")
        print(f"转录结果: {transcription}")

        # Cache the transcription for 1 hour.
        redis_asr_client.setex(cache_key, 3600, transcription)

        result_data = {
            'transcription': transcription
        }
        redis_asr_client.publish('asr_results', json.dumps(result_data))

        # Mark the task as done.
        redis_task_client.set(f"task_status:{cache_key}", "completed")

    except Exception as e:
        print(f"处理音频文件时发生错误: {str(e)}")
        # Mark the task as failed so callers stop polling.
        redis_task_client.set(f"task_status:{cache_key}", "error")
    finally:
        # Bug fix: delete the uploaded file on failure too, not only on
        # success; the original os.remove() was skipped on exceptions.
        if os.path.exists(file_path):
            os.remove(file_path)
|
||||
|
||||
async def kafka_consumer():
    """Consume queued ASR tasks from Kafka and transcribe them in order."""
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8')),
        group_id='asr_group',
        auto_offset_reset='earliest',
        enable_auto_commit=True
    )

    print(f"ASR消费者已启动")

    for record in consumer:
        try:
            task = record.value
            file_path = task.get('file_path')
            task_id = task.get('task_id')
            status = task.get('status')

            # Only well-formed, still-queued tasks get processed.
            if not (file_path and task_id and status == 'queued'):
                print(f"收到无效任务: {task}")
                continue

            cache_key = f"asr:{task_id}"

            print(f"开始处理任务: {cache_key}")
            await process_audio(file_path, cache_key)
            print(f"完成处理任务: {cache_key}")

        except Exception as e:
            print(f"处理消息时发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("启动Kafka消费者处理ASR请求...")
|
||||
asyncio.run(kafka_consumer())
|
||||
@@ -0,0 +1,115 @@
|
||||
import whisper
|
||||
import os
|
||||
import json
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_asr_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_ASR_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
def process_audio(file_path: str, client_id: str, cache_key: str):
    """Transcribe one audio file with Whisper and publish the result.

    Same contract as the async worker variant: tracks task status in
    Redis, caches the transcription for one hour, and publishes it —
    tagged with *client_id* — on 'asr_results'.  The input file is always
    removed afterwards — previously it leaked when transcription raised.
    """
    try:
        # Mark the task as in progress.
        redis_task_client.set(f"task_status:{cache_key}", "processing")

        result = model.transcribe(file_path)
        transcription = result['text']

        print(f"处理了文件: {file_path}")
        print(f"转录结果: {transcription}")

        # Cache the transcription for 1 hour.
        redis_asr_client.setex(cache_key, 3600, transcription)

        # Publish so listeners can route the answer back to the client.
        result_data = {
            'client_id': client_id,
            'transcription': transcription
        }
        redis_asr_client.publish('asr_results', json.dumps(result_data))

        # Mark the task as done.
        redis_task_client.set(f"task_status:{cache_key}", "completed")

    except Exception as e:
        print(f"处理音频文件时发生错误: {str(e)}")
        redis_task_client.set(f"task_status:{cache_key}", "error")
    finally:
        # Bug fix: remove the temp file on both success and failure.
        if os.path.exists(file_path):
            os.remove(file_path)
|
||||
|
||||
def kafka_consumer():
    """Pull queued ASR tasks from Kafka and transcribe each one."""
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8')),
        group_id='asr_group',
        auto_offset_reset='earliest',
        enable_auto_commit=True
    )

    print(f"ASR消费者已启动")

    for record in consumer:
        try:
            task = record.value
            file_path = task.get('file_path')
            task_id = task.get('task_id')
            status = task.get('status')

            # Skip malformed or already-started tasks.
            if not (file_path and task_id and status == 'queued'):
                print(f"收到无效任务: {task}")
                continue

            cache_key = f"asr:{task_id}"
            client_id = task_id  # the task id doubles as the client id

            print(f"开始处理任务: {cache_key}")
            process_audio(file_path, client_id, cache_key)
            print(f"完成处理任务: {cache_key}")

        except Exception as e:
            print(f"处理消息时发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("启动Kafka消费者处理ASR请求...")
|
||||
kafka_consumer()
|
||||
@@ -0,0 +1,117 @@
|
||||
from kafka import KafkaConsumer
|
||||
import json
|
||||
import asyncio
|
||||
import redis
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 3 # 消费者数量
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_CHAT_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
    """Build a Kafka consumer subscribed to the chat topic."""
    consumer = KafkaConsumer(
        KAFKA_CHAT_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        enable_auto_commit=True,
        group_id=KAFKA_CONSUMER_GROUP,
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8'))
    )
    return consumer
|
||||
|
||||
async def process_chat_request(chat_request):
    """Run one chat request against the local Ollama server.

    Loads the session history from Redis, streams the model's answer,
    appends the new (query, answer) pair to the history, and tracks the
    task status ("processing" -> "completed"/"error") in Redis.
    """
    try:
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        # Mark the task as in progress.
        redis_task_client.set(f"task_status:{session_id}", "processing")

        # Previous (query, answer) pairs for this session, if any.
        history = json.loads(redis_client.get(session_id) or '[]')

        # Fold the history into a single prompt ending at the new query.
        full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
        for past_query, past_response in history:
            full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
        full_prompt += f"用户: {query}\n助手:"

        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            "temperature": 0
        }

        # Bug fix: close the streamed HTTP response deterministically;
        # without the context manager the connection leaked on every call.
        text_output = ""
        with requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    json_data = json.loads(line)
                    if 'response' in json_data:
                        text_output += json_data['response']

        # Persist the updated history.
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))

        # Mark the task as done.
        redis_task_client.set(f"task_status:{session_id}", "completed")

        print(f"处理完成 session {session_id}: {text_output}")

    except Exception as e:
        print(f"处理 session {chat_request['session_id']} 时出错: {str(e)}")
        redis_task_client.set(f"task_status:{chat_request['session_id']}", "error")
|
||||
|
||||
def kafka_consumer_thread(consumer_id):
    """Worker loop: pull chat requests from Kafka and process each one."""
    consumer = create_kafka_consumer()
    print(f"消费者 {consumer_id} 已启动")
    for record in consumer:
        asyncio.run(process_chat_request(record.value))
|
||||
|
||||
def main():
    """Spin up KAFKA_CONSUMER_NUM consumer threads and wait on them."""
    print("启动Kafka消费者处理聊天请求...")
    with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
        for worker_id in range(KAFKA_CONSUMER_NUM):
            executor.submit(kafka_consumer_thread, worker_id)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,182 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load MiniCPM3-4B model
|
||||
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
|
||||
|
||||
# 加载 .env 文件
|
||||
# Load variables from the .env file before reading any of them.
load_dotenv()

# CORS configuration.
# Bug fix: os.getenv() returns None when ALLOWED_ORIGINS is unset (it is
# commented out in the shipped .env), and None.split(',') raised
# AttributeError at import time.  Default to an empty string instead,
# which yields [''] and therefore permits no cross-origin callers.
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '').split(',')

# Register the CORS middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'mini3_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "minicpm3-4b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/mini3", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
session_id = request.session_id
|
||||
query = request.query
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
|
||||
for past_query, past_response in history:
|
||||
messages.append({"role": "user", "content": past_query})
|
||||
messages.append({"role": "assistant", "content": past_response})
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
try:
|
||||
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
|
||||
|
||||
# 创建注意力掩码
|
||||
attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()
|
||||
|
||||
# 将输入移动到正确的设备(CPU或GPU)
|
||||
model_inputs = model_inputs.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
model_outputs = model.generate(
|
||||
model_inputs,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=1024,
|
||||
top_p=0.7,
|
||||
temperature=0.7,
|
||||
pad_token_id=tokenizer.eos_token_id, # 将pad_token_id设置为eos_token_id
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
output_token_ids = model_outputs[0][len(model_inputs[0]):]
|
||||
text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
return ChatResponse(response=text_output, history=history)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
|
||||
async def start_chat():
|
||||
session_id = str(uuid.uuid4())
|
||||
return {"session_id": session_id}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -0,0 +1,170 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# 加载 .env 文件
|
||||
# Load variables from the .env file before reading any of them.
load_dotenv()

# CORS configuration.
# Bug fix: os.getenv() returns None when ALLOWED_ORIGINS is unset (it is
# commented out in the shipped .env), and None.split(',') raised
# AttributeError at import time.  Default to an empty string instead,
# which yields [''] and therefore permits no cross-origin callers.
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '').split(',')

# Register the CORS middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer one chat turn via the local Ollama server.

    Rebuilds the prompt from the session's Redis-stored history, streams
    the model output, persists the updated history, and returns both the
    answer and the full history.  Raises HTTPException(500) on any
    generation or storage failure.
    """
    session_id = request.session_id
    query = request.query
    model = request.model

    # Previous (query, answer) pairs for this session, if any.
    history = json.loads(redis_client.get(session_id) or '[]')

    # Fold the history into a single prompt ending with the new query.
    full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
    for past_query, past_response in history:
        full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
    full_prompt += f"用户: {query}"

    data = {
        "model": model,
        "prompt": full_prompt,
        "stream": True,
        "temperature": 0
    }

    try:
        # Bug fix: close the streamed HTTP response deterministically;
        # without the context manager the connection leaked per request.
        text_output = ""
        with requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    json_data = json.loads(line)
                    if 'response' in json_data:
                        text_output += json_data['response']

        # Persist the updated history.
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))

        return ChatResponse(response=text_output, history=history)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
async def start_chat():
    """Create a new chat session and hand back its unique identifier."""
    new_session = uuid.uuid4()
    return {"session_id": str(new_session)}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6001)
|
||||
@@ -0,0 +1,319 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
import asyncio
|
||||
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
v1_chat_app = FastAPI()
|
||||
app.mount("/v1_chat", v1_chat_app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
|
||||
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
|
||||
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
|
||||
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
|
||||
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
|
||||
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
|
||||
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
def get_audio_hash(text):
    """Return the hex MD5 digest of *text* (used as a TTS cache key)."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract the bearer token from the Authorization header.

    Accepts only headers of the form "Bearer obs-..." and returns the
    token; anything else yields HTTPException(401).
    """
    if api_key and api_key.startswith("Bearer "):
        parts = api_key.split(" ")
        token = parts[1]
        # All issued keys carry the "obs-" prefix.
        if token.startswith("obs-"):
            return token
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate *api_key* against Redis and return its metadata.

    Raises HTTPException(401) when the key is unknown, deactivated, or
    past its expiry timestamp.  Returns a dict merging the key record
    with its usage counters.
    """
    redis_key = f"api_key:{api_key}"

    # Full key record is stored as a Redis hash (bytes -> bytes).
    api_key_info = redis_api_client.hgetall(redis_key)

    if not api_key_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")

    # Decode the raw hash into str -> str.
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}

    # 'is_active' is stored as the string '1' for active keys.
    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    # 'expires_at' is an ISO-8601 string; it is compared against an
    # aware UTC now, so it is presumably stored timezone-aware —
    # TODO(review): confirm the writer side.
    expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")

    # Usage counters live in a separate Redis DB under the same key.
    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}

    # Merge record + usage; usage fields win on key collisions.
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info
    }
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Record token consumption for a key, both overall and per model."""
    redis_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()

    # Per-model hash fields, e.g. "tts_tokens_used" / "tts_last_used_at".
    model_tokens_field = f"{model_name}_tokens_used"
    model_last_used_field = f"{model_name}_last_used_at"

    # Batch all four updates into a single round trip.
    pipe = redis_api_usage_client.pipeline()
    pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
    pipe.hset(redis_key, "last_used_at", now_iso)
    pipe.hincrby(redis_key, model_tokens_field, new_tokens_used)
    pipe.hset(redis_key, model_last_used_field, now_iso)
    pipe.execute()
|
||||
|
||||
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge *tokens_required* to the key, enqueue the task, report usage.

    Raises HTTPException(403) when the key's remaining balance is too
    small.  NOTE(review): tokens are charged before producer.send(); if
    the Kafka send fails the tokens are still consumed — confirm this
    ordering is intended.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    # Missing counters default to 0 (new keys have no usage hash yet).
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Record the spend first, then hand the task to the worker topic.
    await update_token_usage(api_key, tokens_required, model_name)

    # Send the task to Kafka (fire-and-forget; result is not awaited).
    producer.send(kafka_topic, task_data)

    # Re-read the key record to report up-to-date counters to the caller.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
|
||||
# 添加WebSocket连接管理
|
||||
class ConnectionManager:
    """Track live WebSocket connections keyed by client id."""

    def __init__(self):
        # client_id -> WebSocket
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the socket and register it under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget *client_id*; a no-op when it is unknown."""
        self.active_connections.pop(client_id, None)

    async def send_message(self, message: str, client_id: str):
        """Push *message* to the client if it is still connected."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.send_text(message)
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
@v1_chat_app.websocket("/ws/{client_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||||
await manager.connect(websocket, client_id)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(client_id)
|
||||
|
||||
# 修改TTS请求处理函数
|
||||
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a TTS job on Kafka and return queueing info plus the task id."""
    task_id = str(uuid.uuid4())
    payload = {
        "task_id": task_id,
        "text": request.text,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # Mark the task as queued before handing it to the worker queue.
    redis_task_client.set(f"task_status:{task_id}", "queued")

    response = await process_request(api_key_info, "tts", 100, payload, KAFKA_TTS_TOPIC)
    response["task_id"] = task_id

    # Remember which API key owns this task for later WebSocket delivery.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])

    return response
|
||||
|
||||
# 修改ASR请求处理函数
|
||||
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Persist the uploaded audio, queue an ASR job and return the task id."""
    task_id = str(uuid.uuid4())

    upload_dir = "/obscura/task/audio_upload"
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, f"{task_id}.wav")

    # Spool the upload to disk so the Kafka worker can read it by path.
    with open(file_path, "wb") as out_file:
        out_file.write(await audio.read())

    payload = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }

    redis_task_client.set(f"task_status:{task_id}", "queued")

    response = await process_request(api_key_info, "asr", 100, payload, KAFKA_ASR_TOPIC)
    response["task_id"] = task_id

    # Remember which API key owns this task for later WebSocket delivery.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])

    return response
|
||||
|
||||
# 修改聊天请求处理函数
|
||||
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a chat job on Kafka and return queueing info plus the task id."""
    task_id = str(uuid.uuid4())
    payload = {
        "task_id": task_id,
        "session_id": request.session_id,
        "query": request.query,
        "model": request.model,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    redis_task_client.set(f"task_status:{task_id}", "queued")

    response = await process_request(api_key_info, "chat", 100, payload, KAFKA_CHAT_TOPIC)
    response["task_id"] = task_id

    # Remember which API key owns this task for later WebSocket delivery.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])

    return response
|
||||
|
||||
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a chat task's status; include the conversation history once done."""
    raw_status = redis_task_client.get(f"task_status:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        chat_result = redis_chat_client.get(task_id)
        if chat_result:
            # The worker stores the whole history as a JSON blob.
            return {
                "status": "completed",
                "history": json.loads(chat_result)
            }
    return {"status": status}
|
||||
|
||||
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a TTS task's status; include the audio file path once done."""
    raw_status = redis_task_client.get(f"task_status:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        audio_info = redis_tts_client.get(task_id)
        if audio_info:
            return {
                "status": "completed",
                "audio_path": json.loads(audio_info)['path']
            }
    return {"status": status}
|
||||
|
||||
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll an ASR task's status; include the transcription once done."""
    raw_status = redis_task_client.get(f"task_status:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        transcription = redis_asr_client.get(task_id)
        if transcription:
            return {
                "status": "completed",
                "transcription": transcription.decode('utf-8')
            }
    return {"status": status}
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn
    # Serve the gateway API on all interfaces, port 8008.
    uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
@@ -0,0 +1,180 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration.
# Bug fix: ALLOWED_ORIGINS is commented out in the shipped .env, so
# os.getenv() returns None and None.split(',') crashed at import time.
# Fall back to "*" (allow all origins) when the variable is unset.
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
    """Return the hex MD5 digest of *text*, used as a cache key / filename."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
    """Run GPT-SoVITS inference for *target_text*.

    Writes <md5(target_text)>.wav into the *output_path* directory and
    returns that file path, or None when the model produced no audio.
    """
    # Synthesize audio
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)

    # get_tts_wav yields (sampling_rate, audio) chunks; keep the final one.
    result_list = list(synthesis_result)

    if result_list:
        last_sampling_rate, last_audio_data = result_list[-1]
        audio_hash = get_audio_hash(target_text)
        output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        return output_wav_path
    else:
        return None
|
||||
|
||||
@app.post("/tts")
async def synthesize_audio(request: TTSRequest):
    """Serve synthesized speech for request.text.

    Fast paths: a Redis-cached path or an already-rendered file on disk.
    Otherwise the text is queued on Kafka and we poll for the output file
    until the consumer thread has rendered it (up to 60 seconds).
    """
    import asyncio  # local import: keeps the module's import block unchanged

    try:
        print(f"Received TTS request: {request.dict()}")
        target_text = request.text
        audio_hash = get_audio_hash(target_text)

        # 1) Redis cache hit.
        cached_audio = redis_client.get(audio_hash)
        if cached_audio:
            audio_info = json.loads(cached_audio)
            return FileResponse(audio_info['path'], media_type="audio/wav")

        # 2) Already rendered on disk: backfill the cache and serve it.
        file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
        if os.path.exists(file_path):
            redis_client.set(audio_hash, json.dumps({"path": file_path}))
            return FileResponse(file_path, media_type="audio/wav")

        # 3) Queue the job for the Kafka consumer thread.
        kafka_producer.send(KAFKA_TOPIC, json.dumps({
            'text': target_text,
            'audio_hash': audio_hash
        }).encode('utf-8'))

        # Poll for the rendered file for up to 60 seconds (the old comment
        # claimed 30s for a 60 * 1s loop). Bug fix: use asyncio.sleep, not
        # time.sleep — blocking sleep in an async handler starves the loop.
        for _ in range(60):
            if os.path.exists(file_path):
                return FileResponse(file_path, media_type="audio/wav")
            await asyncio.sleep(1)

        raise HTTPException(status_code=504, detail="Audio generation timed out")
    except HTTPException:
        # Bug fix: the 504 above used to be swallowed by the broad Exception
        # handler below and re-raised as a 500; let it propagate unchanged.
        raise
    except Exception as e:
        print(f"Error processing TTS request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
    """Consume TTS jobs from Kafka, render them and cache the result path."""
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        # group_id=KAFKA_GROUP_ID,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )

    for message in consumer:
        target = message.value['text']
        digest = message.value['audio_hash']

        rendered_path = synthesize(target, OUTPUT_PATH)

        if rendered_path:
            # Publish the rendered file path so the HTTP handler can serve it.
            redis_client.set(digest, json.dumps({"path": rendered_path}))
            print(f"Audio synthesized successfully: {rendered_path}")
        else:
            print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6002)
|
||||
@@ -0,0 +1,180 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration.
# Bug fix: ALLOWED_ORIGINS is commented out in the shipped .env, so
# os.getenv() returns None and None.split(',') crashed at import time.
# Fall back to "*" (allow all origins) when the variable is unset.
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
    """MD5 hex digest of *text*; doubles as cache key and wav filename."""
    hasher = hashlib.md5()
    hasher.update(text.encode())
    return hasher.hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
|
||||
# Synthesize audio
|
||||
with torch.cuda.device(device):
|
||||
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(REF_LANGUAGE),
|
||||
text=target_text,
|
||||
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
return output_wav_path
|
||||
else:
|
||||
return None
|
||||
|
||||
@app.post("/tts_ko")
async def synthesize_audio(request: TTSRequest):
    """Serve synthesized Korean speech for request.text.

    Fast paths: a Redis-cached path or an already-rendered file on disk.
    Otherwise the text is queued on Kafka and we poll for the output file
    until the consumer thread has rendered it (up to 60 seconds).
    """
    import asyncio  # local import: keeps the module's import block unchanged

    try:
        print(f"Received TTS request: {request.dict()}")
        target_text = request.text
        audio_hash = get_audio_hash(target_text)

        # 1) Redis cache hit.
        cached_audio = redis_client.get(audio_hash)
        if cached_audio:
            audio_info = json.loads(cached_audio)
            return FileResponse(audio_info['path'], media_type="audio/wav")

        # 2) Already rendered on disk: backfill the cache and serve it.
        file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
        if os.path.exists(file_path):
            redis_client.set(audio_hash, json.dumps({"path": file_path}))
            return FileResponse(file_path, media_type="audio/wav")

        # 3) Queue the job for the Kafka consumer thread.
        kafka_producer.send(KAFKA_TOPIC, json.dumps({
            'text': target_text,
            'audio_hash': audio_hash
        }).encode('utf-8'))

        # Poll for the rendered file for up to 60 seconds (the old comment
        # claimed 30s for a 60 * 1s loop). Bug fix: use asyncio.sleep, not
        # time.sleep — blocking sleep in an async handler starves the loop.
        for _ in range(60):
            if os.path.exists(file_path):
                return FileResponse(file_path, media_type="audio/wav")
            await asyncio.sleep(1)

        raise HTTPException(status_code=504, detail="Audio generation timed out")
    except HTTPException:
        # Bug fix: the 504 above used to be swallowed by the broad Exception
        # handler below and re-raised as a 500; let it propagate unchanged.
        raise
    except Exception as e:
        print(f"Error processing TTS request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
# group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='latest',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
|
||||
for message in consumer:
|
||||
target_text = message.value['text']
|
||||
audio_hash = message.value['audio_hash']
|
||||
|
||||
output_path = synthesize(target_text, OUTPUT_PATH)
|
||||
|
||||
if output_path:
|
||||
redis_client.set(audio_hash, json.dumps({"path": output_path}))
|
||||
print(f"Audio synthesized successfully: {output_path}")
|
||||
else:
|
||||
print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -0,0 +1,176 @@
|
||||
# 导入所需的库
|
||||
import os
|
||||
import soundfile as sf
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaConsumer
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
|
||||
"""
|
||||
整体设计说明:
|
||||
这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务,
|
||||
使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。
|
||||
主要功能包括:
|
||||
1. 初始化配置和模型
|
||||
2. 提供语音合成功能
|
||||
3. 监听Kafka消息并处理TTS任务
|
||||
4. 将合成结果存储到Redis并更新任务状态
|
||||
"""
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 设置GPU设备(如果可用)
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 从环境变量中读取Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 从环境变量中读取Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
|
||||
# 从环境变量中读取TTS相关配置
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# 初始化Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TTS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# 初始化国际化工具
|
||||
i18n = I18nAuto()
|
||||
|
||||
def get_audio_hash(text):
    """Return the MD5 hash of *text* as a hex string.

    The hash doubles as the output wav filename, so identical texts map
    to the same audio file.
    """
    md5 = hashlib.md5(text.encode())
    return md5.hexdigest()
|
||||
|
||||
# 初始化模型
|
||||
print("正在初始化模型...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# 读取参考文本
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("模型初始化成功。")
|
||||
|
||||
def synthesize(target_text, output_wav_path):
    """Synthesize speech for *target_text* with the GPT-SoVITS model.

    Args:
        target_text: Text to render as speech.
        output_wav_path: Destination path for the generated wav file.

    Returns:
        The output path on success, or None when the model yielded no audio.
    """
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)

    # The generator yields (sampling_rate, audio) chunks; keep the last one.
    result_list = list(synthesis_result)

    if result_list:
        last_sampling_rate, last_audio_data = result_list[-1]
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        return output_wav_path
    else:
        return None
|
||||
|
||||
def kafka_consumer():
    """Consume TTS tasks from Kafka and process them.

    For each message: mark the task "processing", synthesize (or reuse an
    already-rendered wav), store the result path in Redis and mark the task
    "completed"; on any failure mark it "failed".
    """
    consumer = KafkaConsumer(
        KAFKA_TTS_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    print(f"TTS消费者已启动")
    for message in consumer:
        # Bug fix: task_id must be bound before the except block references
        # it — previously a malformed message (missing keys) raised a
        # NameError inside the handler, masking the real error.
        task_id = None
        try:
            task_id = message.value['task_id']
            target_text = message.value['text']
            text_hash = message.value['text_hash']

            # Mark the task as "processing".
            redis_task_client.set(f"task_status:tts:{task_id}", "processing")

            output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav")

            # Re-check existence in case another process rendered it already.
            if not os.path.exists(output_wav_path):
                output_path = synthesize(target_text, output_wav_path)
            else:
                output_path = output_wav_path

            if output_path:
                # Store the result path in the TTS result DB.
                redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path}))
                print(f"音频合成成功: {output_path}")

                # Mark the task as "completed".
                redis_task_client.set(f"task_status:tts:{task_id}", "completed")
            else:
                print("音频合成失败")

                # Mark the task as "failed".
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
        except Exception as e:
            print(f"处理消息时出错: {str(e)}")

            # Only mark "failed" when we actually know which task this was.
            if task_id is not None:
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置CUDA设备
|
||||
torch.cuda.set_device(device)
|
||||
# 启动Kafka消费者
|
||||
kafka_consumer()
|
||||
@@ -0,0 +1,186 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import whisper
|
||||
import tempfile
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
import json
|
||||
import threading
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# CORS configuration.
# Bug fix: ALLOWED_ORIGINS is commented out in the shipped .env, so
# os.getenv() returns None and None.split(',') crashed at import time.
# Fall back to "*" (allow all origins) when the variable is unset.
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*').split(',')
|
||||
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 添加密码
|
||||
)
|
||||
|
||||
# Kafka生产者
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 存储WebSocket连接的字典
|
||||
active_connections = {}
|
||||
|
||||
@app.websocket("/asr/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Keep a WebSocket open per client so transcriptions can be pushed back.

    Echoes "pong" to "ping" frames; after a 30s receive timeout sends a
    "heartbeat" frame to detect dead peers. The connection is removed from
    the registry on every exit path.
    """
    await websocket.accept()
    active_connections[client_id] = websocket
    try:
        while True:
            try:
                # Receive with a timeout so idle clients can be heartbeated.
                data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
                if data == "ping":
                    await websocket.send_text("pong")
                else:
                    await websocket.send_text(f"收到消息: {data}")
            except asyncio.TimeoutError:
                try:
                    # Send a heartbeat; failure means the peer is gone.
                    await websocket.send_text("heartbeat")
                except WebSocketDisconnect:
                    logger.info(f"客户端 {client_id} 断开连接")
                    break
    except WebSocketDisconnect:
        logger.info(f"客户端 {client_id} 断开连接")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Always drop the registry entry so a stale socket is never reused.
        if client_id in active_connections:
            del active_connections[client_id]
|
||||
|
||||
|
||||
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
    """Accept an audio upload and queue it for Whisper transcription.

    Returns immediately with a client_id; the transcription is cached in
    Redis and pushed over the client's WebSocket when ready.
    """
    import hashlib  # local import: needed for the content-based cache key

    if not audio:
        raise HTTPException(status_code=400, detail="未提供音频文件")

    client_id = str(uuid.uuid4())

    content = await audio.read()

    # Bug fix: the cache key used to embed the per-request client_id (a
    # fresh uuid4), so the Redis cache could never hit. Key on the audio
    # content itself so identical uploads reuse the cached transcription.
    cache_key = f"asr:{hashlib.md5(content).hexdigest()}"

    # Serve from cache when the identical audio was transcribed recently.
    cached_result = redis_client.get(cache_key)
    if cached_result:
        logger.info(f"缓存命中: {cache_key}")
        return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}

    # Spool the upload to a temp file the Kafka worker can read by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
        temp_audio.write(content)
        temp_audio.flush()

        task = {
            'file_path': temp_audio.name,
            'client_id': client_id,
            'cache_key': cache_key
        }
        producer.send(KAFKA_TOPIC, value=task)
        producer.flush()

    logger.info(f"发送任务到Kafka: {task}")
    return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
|
||||
|
||||
async def send_transcription(client_id: str, transcription: str):
    """Push a transcription to the client's WebSocket, if still connected."""
    websocket = active_connections.get(client_id)
    if websocket is None:
        logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
        return
    await websocket.send_json({"transcription": transcription})
|
||||
|
||||
def kafka_consumer(consumer_id):
    """Worker loop: pull ASR tasks from Kafka, transcribe, cache and notify.

    NOTE(review): asyncio.run(send_transcription(...)) spins up a fresh event
    loop on this consumer thread, while the client WebSocket lives on the
    server's loop — cross-loop sends may misbehave; confirm in integration.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda x: json.loads(x.decode('utf-8')),
        group_id='asr_group',
        # Allow up to 5 minutes per poll cycle for long transcriptions.
        max_poll_interval_ms=300000
    )

    for message in consumer:
        try:
            task = message.value
            file_path = task.get('file_path')
            client_id = task.get('client_id')
            cache_key = task.get('cache_key')

            # Skip (and commit) malformed tasks so they are not redelivered.
            if not file_path or not client_id or not cache_key:
                logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
                consumer.commit()
                continue

            result = model.transcribe(file_path)

            logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
            logger.info(f"转录结果: {result['text']}")

            # Cache the transcription for one hour.
            redis_client.setex(cache_key, 3600, result['text'])

            asyncio.run(send_transcription(client_id, result['text']))

            # Delete the spooled upload and acknowledge the message.
            os.remove(file_path)
            consumer.commit()
        except Exception as e:
            logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
|
||||
|
||||
def start_consumers(num_consumers=1):
    """Spawn *num_consumers* background Kafka consumer threads."""
    for consumer_id in range(num_consumers):
        worker = threading.Thread(target=kafka_consumer, args=(consumer_id,))
        worker.start()
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_consumers()
|
||||
uvicorn.run(app, host="0.0.0.0", port=6000)
|
||||
@@ -0,0 +1,122 @@
|
||||
from kafka import KafkaConsumer
|
||||
import json
|
||||
import asyncio
|
||||
import redis
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 1 # 消费者数量
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_CHAT_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
KAFKA_CHAT_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='latest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
async def process_chat_request(chat_request):
    """Handle one chat task: build a prompt with history, call the LLM, store results.

    Task status in Redis transitions: processing -> completed, or error.
    """
    # Bug fix: bind task_id before the try block — the except handler
    # references it, and previously a message missing 'task_id' raised
    # NameError inside the handler instead of recording the error.
    task_id = chat_request.get('task_id', 'unknown')
    try:
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        # Mark the task as "processing".
        redis_task_client.set(f"chat:{task_id}:status", "processing")

        # Conversation history keyed by session_id: list of (query, reply).
        history = json.loads(redis_client.get(f"chat:{session_id}") or '[]')

        # Build the full prompt: system prompt + prior turns + current query.
        full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
        for past_query, past_response in history:
            full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
        full_prompt += f"用户: {query}\n助手:"

        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            "temperature": 0
        }

        response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
        response.raise_for_status()

        # Accumulate the streamed response chunks into one string.
        text_output = ""
        for line in response.iter_lines():
            if line:
                json_data = json.loads(line)
                if 'response' in json_data:
                    text_output += json_data['response']

        # Append this turn to the session history.
        history.append((query, text_output))
        redis_client.set(f"chat:{session_id}", json.dumps(history))

        # Mark completed and store the response/result for pollers.
        redis_task_client.set(f"chat:{task_id}:status", "completed")
        redis_task_client.set(f"chat:{task_id}:response", text_output)
        redis_task_client.set(f"chat:{task_id}:result", json.dumps({"query": query, "response": text_output}))

        print(f"处理完成 task_id {task_id}, session_id {session_id}: {text_output}")

    except Exception as e:
        print(f"处理 task {task_id} 时出错: {str(e)}")
        # Mark the task as "error" and record the message for diagnostics.
        redis_task_client.set(f"chat:{task_id}:status", "error")
        redis_task_client.set(f"chat:{task_id}:error", str(e))
|
||||
def kafka_consumer_thread(consumer_id):
    """Consume chat requests from Kafka and process each one in turn."""
    consumer = create_kafka_consumer()
    print(f"消费者 {consumer_id} 已启动")
    for message in consumer:
        # Each request runs in its own short-lived event loop.
        asyncio.run(process_chat_request(message.value))
|
||||
|
||||
def main():
    """Start KAFKA_CONSUMER_NUM consumer threads and block until they exit.

    The ThreadPoolExecutor context manager waits on all submitted workers,
    so this call effectively runs forever while consumers are alive.
    """
    print("启动Kafka消费者处理聊天请求...")
    with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as pool:
        for worker_id in range(KAFKA_CONSUMER_NUM):
            pool.submit(kafka_consumer_thread, worker_id)


if __name__ == '__main__':
    main()
|
||||
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
from moviepy.editor import VideoFileClip
|
||||
|
||||
def mp4_to_wav(input_file, output_file):
    """Convert an MP4 file's audio track to a WAV file.

    :param input_file: path of the source MP4 file
    :param output_file: path of the WAV file to write

    Errors are reported on stdout rather than raised, matching the original
    best-effort batch behavior. Fixes over the original:
    - the video/audio handles are now closed in a ``finally`` block, so they
      are not leaked when ``write_audiofile`` (or loading) fails;
    - a video without an audio track (``video.audio is None``) now produces a
      clear failure message instead of an AttributeError.
    """
    video = None
    audio = None
    try:
        video = VideoFileClip(input_file)
        audio = video.audio
        if audio is None:
            # Silent container: nothing to extract.
            raise ValueError("视频不包含音轨")
        audio.write_audiofile(output_file)
        print(f"转换成功: {input_file} -> {output_file}")
    except Exception as e:
        print(f"转换失败: {input_file} - {str(e)}")
    finally:
        # Release resources regardless of success or failure.
        if audio is not None:
            audio.close()
        if video is not None:
            video.close()
|
||||
|
||||
def process_directory(directory):
    """Convert every .mp4 file directly inside *directory* to a sibling .wav.

    :param directory: path of a directory containing MP4 files
    """
    mp4_names = (name for name in os.listdir(directory)
                 if name.lower().endswith('.mp4'))
    for name in mp4_names:
        source = os.path.join(directory, name)
        target = os.path.splitext(source)[0] + ".wav"
        mp4_to_wav(source, target)
|
||||
|
||||
def main():
    """Interactive entry point: convert one MP4 file, or all MP4s in a directory."""
    input_path = input("请输入MP4文件或包含MP4文件的目录路径: ").strip()

    # Reject paths that do not exist at all.
    if not os.path.exists(input_path):
        print("错误: 输入路径不存在")
        return

    if os.path.isfile(input_path):
        # Single-file mode: must actually be an .mp4.
        if not input_path.lower().endswith('.mp4'):
            print("错误: 输入文件不是MP4格式")
            return
        mp4_to_wav(input_path, os.path.splitext(input_path)[0] + ".wav")
    elif os.path.isdir(input_path):
        # Directory mode: batch-convert everything inside.
        process_directory(input_path)
    else:
        print("错误: 输入路径既不是文件也不是目录")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,136 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import httpx
|
||||
import json
|
||||
import redis
|
||||
from typing import List, Dict, Optional
|
||||
import logging
|
||||
import ollama
|
||||
import uuid
|
||||
|
||||
app = FastAPI()

# NOTE(review): CORS is wide open (any origin, with credentials) — confirm
# this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection used to archive request/response pairs under "request:{id}".
# NOTE(review): host/port/db/password are hard-coded here, while sibling
# services read them from .env — consider moving these to the environment.
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Body of POST /generate: a single-turn chat prompt."""
    model: Optional[str] = "qwen2.5:3b"  # ollama model tag
    prompt: str  # user prompt, forwarded verbatim to ollama.chat
|
||||
|
||||
class RawGenerateRequest(BaseModel):
    """Body of POST /api/generate: mirrors ollama.generate's keyword arguments."""
    model: Optional[str] = "qwen2.5:3b"  # ollama model tag
    prompt: str
    system_prompt: Optional[str] = None  # passed to ollama as `system`
    stream: Optional[bool] = False  # request streaming from ollama
    raw: Optional[bool] = False  # accepted but NOT forwarded to ollama.generate
    format: Optional[str] = None  # e.g. "json"
    options: Optional[Dict] = None  # extra ollama generation options
|
||||
|
||||
class GenerateResponse(BaseModel):
    """Envelope returned by POST /generate."""
    response: dict  # {"response": <generated text>, "model": <model tag>}
    request_id: str  # key suffix of the Redis archive entry "request:{request_id}"
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
logger.info(f"收到请求: {request}")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
|
||||
full_response = response['message']['content']
|
||||
|
||||
request_data = {
|
||||
"model": request.model,
|
||||
"prompt": request.prompt,
|
||||
"response": full_response
|
||||
}
|
||||
|
||||
redis_client.set(f"request:{request_id}", json.dumps(request_data))
|
||||
|
||||
response_data = {
|
||||
"response": full_response,
|
||||
"model": request.model
|
||||
}
|
||||
|
||||
return GenerateResponse(response=response_data, request_id=request_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate_without_history(request: RawGenerateRequest):
|
||||
"""
|
||||
处理无历史记录的生成请求。
|
||||
|
||||
参数:
|
||||
- request: RawGenerateRequest对象,包含生成请求的所有参数。
|
||||
|
||||
返回:
|
||||
- 包含生成结果的字典。
|
||||
"""
|
||||
try:
|
||||
response = ollama.generate(
|
||||
model=request.model,
|
||||
prompt=request.prompt,
|
||||
system=request.system_prompt,
|
||||
format=request.format,
|
||||
options=request.options,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"model": request.model,
|
||||
"response": response['response'],
|
||||
"done": True,
|
||||
"context": response.get('context'),
|
||||
"total_duration": response.get('total_duration'),
|
||||
"load_duration": response.get('load_duration'),
|
||||
"prompt_eval_count": response.get('prompt_eval_count'),
|
||||
"prompt_eval_duration": response.get('prompt_eval_duration'),
|
||||
"eval_count": response.get('eval_count'),
|
||||
"eval_duration": response.get('eval_duration')
|
||||
}
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
redis_client.set(f"request:{request_id}", json.dumps(response_data))
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生未预期的错误: {e}")
|
||||
logger.exception("详细错误信息:")
|
||||
raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
|
||||
|
||||
@app.get("/request/{request_id}", response_model=Dict)
|
||||
async def get_request(request_id: str):
|
||||
request_data = redis_client.get(f"request:{request_id}")
|
||||
if request_data:
|
||||
return json.loads(request_data)
|
||||
raise HTTPException(status_code=404, detail="请求未找到")
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
return ollama.list()
|
||||
|
||||
@app.get("/models/{model_name}")
|
||||
async def show_model(model_name: str):
|
||||
return ollama.show(model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -0,0 +1,406 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# 在文件顶部添加这个函数
|
||||
def get_audio_hash(text):
    """Return the hex MD5 digest of *text*, used as the TTS audio cache key."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
|
||||
|
||||
# Load configuration from the .env file before anything reads os.environ.
load_dotenv()

app = FastAPI()
# All public endpoints live on a sub-application mounted under /v1_chat.
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)

# NOTE(review): CORS fully open (any origin, with credentials) — confirm
# this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Core broker / store configuration. The int() casts raise TypeError at
# import time if the corresponding environment variable is missing.
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))

# Per-voice Redis DB numbers.
# NOTE(review): the REDIS_*_DB voice variables below are not present in the
# committed .env sample — if any is undefined, int(None) fails on import.
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))

# Kafka topics for the three task pipelines.
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')

OUTPUT_PATH= os.getenv('OUTPUT_PATH')

# Kafka producer shared by all endpoints; values are JSON-encoded dicts.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)

# Redis clients, one per logical database.
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)

# One Redis client per voice; each voice caches synthesized audio metadata
# under "tts:{md5(text)}" in its own database.
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)

# Map a requested voice id to the Redis client holding its audio cache;
# "default" aliases the girl voice.
voice_to_redis = {
    "default": redis_tts_girl,
    "girl": redis_tts_girl,
    "woman": redis_tts_woman,
    "man": redis_tts_man,
    "leijun": redis_tts_leijun,
    "dufu": redis_tts_dufu,
    "hejiong": redis_tts_hejiong,
    "mahuateng": redis_tts_mahuateng,
    "lidan": redis_tts_lidan,
    "yuhua": redis_tts_yuhua,
    "liuzhenyun": redis_tts_liuzhenyun,
    "dabing": redis_tts_dabing,
    "luoxiang": redis_tts_luoxiang,
    "xuzhiyuan": redis_tts_xuzhiyuan
}


# API keys arrive in the standard Authorization header ("Bearer obs-...").
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
def get_audio_hash(text):
    """MD5 hex digest of *text* — cache key for synthesized audio.

    (This duplicates the definition near the top of the file; keeping both
    so existing references are unaffected.)
    """
    md5 = hashlib.md5(text.encode())
    return md5.hexdigest()
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract and syntactically check the bearer token from Authorization.

    Accepts only "Bearer obs-..." and returns the bare token; everything
    else gets a 401 with a WWW-Authenticate challenge.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if api_key and api_key.startswith("Bearer "):
        token = api_key.split(" ")[1]
        if token.startswith("obs-"):
            return token
    raise unauthorized
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate the key against Redis and return its metadata plus usage stats.

    Raises HTTPException(401) when the key is unknown, deactivated, or expired.
    """
    def _as_text(hash_bytes):
        # redis-py returns bytes for both keys and values; normalize to str.
        return {k.decode('utf-8'): v.decode('utf-8') for k, v in hash_bytes.items()}

    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    info = _as_text(raw_info)

    if info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    if datetime.now(timezone.utc) > datetime.fromisoformat(info.get('expires_at')):
        raise HTTPException(status_code=401, detail="API密钥已过期")

    usage = _as_text(redis_api_usage_client.hgetall(redis_key))

    return {
        "api_key": api_key,
        **info,
        **usage
    }
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Atomically (via one pipeline) charge tokens against both the global
    and the per-model usage counters, stamping last-used timestamps."""
    key = f"api_key:{api_key}"
    now = datetime.now(timezone.utc).isoformat()

    pipeline = redis_api_usage_client.pipeline()
    # Global counters first, then the per-model pair — same order as before.
    for counter_field, stamp_field in (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    ):
        pipeline.hincrby(key, counter_field, new_tokens_used)
        pipeline.hset(key, stamp_field, now)
    pipeline.execute()
|
||||
|
||||
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge the caller's token quota and publish a task to Kafka.

    Returns a summary dict with queue confirmation and updated token usage.
    Raises HTTPException(403) when the quota would be exceeded.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    # Current quota (total_tokens) and consumption (tokens_used); a missing
    # field counts as 0.
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # NOTE(review): the quota check above and the charge below are not atomic,
    # so concurrent requests can overdraw; tokens are also charged even if the
    # Kafka publish later fails — confirm this is acceptable.
    # Update token usage.
    await update_token_usage(api_key, tokens_required, model_name)

    # Fire-and-forget publish (the returned future is not awaited or checked).
    producer.send(kafka_topic, task_data)

    # Re-read usage so the response reflects the charge just made.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        # total_tokens was read before the charge; remaining is computed
        # against the freshly re-read usage.
        "tokens_remaining": total_tokens - new_tokens_used
    }
|
||||
|
||||
class TTSRequest(BaseModel):
    """Body of POST /v1_chat/tts."""
    text: str  # text to synthesize
    voice: str = Field(..., description="选择的音色")  # must be one of the ids listed by /getvoice
|
||||
|
||||
class ChatRequest(BaseModel):
    """Body of POST /v1_chat/chat."""
    session_id: str  # conversation identifier; forwarded to the chat consumer for history keying
    query: str  # user message
    model: str = "qwen2.5:3b"  # model tag for the downstream generator
|
||||
|
||||
|
||||
@v1_chat_app.post("/tts")
|
||||
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
text_hash = get_audio_hash(request.text)
|
||||
|
||||
# 验证音色选择
|
||||
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||
if request.voice not in valid_voices:
|
||||
raise HTTPException(status_code=400, detail="无效的音色选择")
|
||||
|
||||
# 如果声音是 'default',则将其视为 'girl'
|
||||
voice = 'girl' if request.voice == 'default' else request.voice
|
||||
|
||||
# 使用对应音色的 Redis 客户端
|
||||
redis_tts = voice_to_redis[request.voice]
|
||||
|
||||
# 检查是否已存在相同内容的音频文件
|
||||
existing_audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||
if existing_audio_info:
|
||||
existing_audio_path = json.loads(existing_audio_info)['path']
|
||||
if os.path.exists(existing_audio_path):
|
||||
return {
|
||||
"message": "TTS请求已完成",
|
||||
"task_id": task_id,
|
||||
"status": "completed",
|
||||
"audio_path": existing_audio_path
|
||||
}
|
||||
|
||||
# 如果不存在,创建新的任务
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"text": request.text,
|
||||
"text_hash": text_hash,
|
||||
"voice": request.voice,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 存储任务信息到Redis
|
||||
redis_task_client.set(f"task_status:tts:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
@v1_chat_app.post("/asr")
|
||||
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
UPLOAD_DIR = "/obscura/task/audio_upload"
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
|
||||
|
||||
with open(file_path, "wb") as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
|
||||
task_data = {
|
||||
'file_path': file_path,
|
||||
'task_id': task_id,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# 存储任务状态,使用一致的键名格式
|
||||
redis_task_client.set(f"task_status:asr:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
@v1_chat_app.post("/chat")
|
||||
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"session_id": request.session_id,
|
||||
"query": request.query,
|
||||
"model": request.model,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# 设置任务状态为 "queued"
|
||||
redis_task_client.set(f"chat:{task_id}:status", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
|
||||
@v1_chat_app.get("/chat_result/{task_id}")
|
||||
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"chat:{task_id}:status")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis任务数据库获取聊天结果
|
||||
chat_result = redis_task_client.get(f"chat:{task_id}:result")
|
||||
if chat_result:
|
||||
result = json.loads(chat_result)
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": result
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_result/{task_id}")
|
||||
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||
if task_info:
|
||||
task_data = json.loads(task_info)
|
||||
text_hash = task_data['text_hash']
|
||||
voice = task_data['voice']
|
||||
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||
|
||||
audio_info = redis_tts.get(f"tts:{text_hash}")
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
return {
|
||||
"status": "completed",
|
||||
"audio_path": audio_path
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/asr_result/{task_id}")
|
||||
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态,使用一致的键名格式
|
||||
task_status = redis_task_client.get(f"task_status:asr:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis ASR结果数据库获取转录结果
|
||||
transcription = redis_asr_client.get(f"asr:{task_id}")
|
||||
if transcription:
|
||||
return {
|
||||
"status": "completed",
|
||||
"transcription": transcription.decode('utf-8')
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_audio/{task_id}")
|
||||
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_status = redis_task_client.get(f"task_status:tts:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从任务信息中获取使用的音色
|
||||
task_info = redis_task_client.get(f"task_info:tts:{task_id}")
|
||||
if task_info:
|
||||
task_data = json.loads(task_info)
|
||||
voice = task_data.get('voice', 'girl') # 默认使用 'girl'
|
||||
# 'default' 和 'girl' 都使用 girl 的 Redis
|
||||
redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]
|
||||
|
||||
# 从对应音色的 Redis 获取音频文件路径
|
||||
audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
if os.path.exists(audio_path):
|
||||
file_name = os.path.basename(audio_path)
|
||||
return FileResponse(audio_path, media_type="audio/wav", filename=file_name)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="音频文件不存在")
|
||||
elif status == "queued" or status == "processing":
|
||||
raise HTTPException(status_code=202, detail="音频文件正在生成中")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="任务处理失败")
|
||||
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
@v1_chat_app.get("/getvoice")
|
||||
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
|
||||
valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
|
||||
return {"available_voices": valid_voices}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
那一年济南冬天零下十几度
|
||||
我暖气费交不上
|
||||
因为从那年往前推好几年
|
||||
我不接任何的商业了早就
|
||||
因为不愿意再唱唐会了嘛
|
||||
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
金无足赤人无完人嘛
|
||||
连朕也出过错误
|
||||
就说这黄羽全图吧
|
||||
朕每次见着他
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
一些研究表明一個晚上良好的睡眠
|
||||
就能幫助大腦恢復到最佳狀態
|
||||
所以如果你已經一週
|
||||
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
很多年前我是主持人
|
||||
做音乐节目
|
||||
然后当时我们节目敲了当年最红的一个歌手
|
||||
叫陈冠希
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
我们在短短的半年之间的时间里面
|
||||
就组成了超过一千人的团队
|
||||
在过去三年多的时间里面
|
||||
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
他去那个商场
|
||||
两口子去逛 买电视
|
||||
大家知道现在的智能电视那个遥控器都有一个语音搜索功能
|
||||
年轻人不怎么用其实
|
||||
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
在我们村里
|
||||
最有见识的人呢
|
||||
是我舅
|
||||
他是个赶马车的
|
||||
他不但去过县城
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
所以无论是中外
|
||||
最大限度的反抗标准
|
||||
在很长一段时间都是一种主流立场
|
||||
Binary file not shown.
@@ -0,0 +1,2 @@
|
||||
那么今天呢 政治主持人介绍是我们第四次的互联网家峰会
|
||||
那么这次的规模是世界以来规模最大的
|
||||
Binary file not shown.
@@ -0,0 +1,2 @@
|
||||
今年以來 我國全力推動鄉村產業全鏈條升級
|
||||
鄉村產業振興呈現良好勢頭
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
政法機關堅持黨對政法工作的絕對領導
|
||||
推動政法體制和工作機制
|
||||
實現歷史性變革
|
||||
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
我们约了他去做他的采访
|
||||
他已经答应了
|
||||
然后结果去那天
|
||||
他那时候已经得了白血病了
|
||||
他说真是不巧不好意思
|
||||
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
很小很小的房间
|
||||
来人的话呢
|
||||
如果一个人说要外出去上厕所
|
||||
因为厕所是公共厕所
|
||||
所有人都得起来走到外面去
|
||||
Binary file not shown.
+315
@@ -0,0 +1,315 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
import traceback
|
||||
from kafka import KafkaConsumer
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka 配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
|
||||
# TTS 配置
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Redis 配置
|
||||
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
|
||||
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
|
||||
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
|
||||
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
|
||||
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
|
||||
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
|
||||
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
|
||||
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
|
||||
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
|
||||
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
|
||||
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
|
||||
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
|
||||
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
|
||||
|
||||
# 初始化 Redis 客户端
|
||||
redis_tts_girl = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
|
||||
redis_tts_woman = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
|
||||
redis_tts_man = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
|
||||
redis_tts_leijun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
|
||||
redis_tts_dufu = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
|
||||
redis_tts_hejiong = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
|
||||
redis_tts_mahuateng = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
|
||||
redis_tts_lidan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
|
||||
redis_tts_yuhua = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
|
||||
redis_tts_liuzhenyun = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
|
||||
redis_tts_dabing = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
|
||||
redis_tts_luoxiang = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
|
||||
redis_tts_xuzhiyuan = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
|
||||
|
||||
redis_task_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_TASK_DB, password=REDIS_PASSWORD)
|
||||
|
||||
# 创建音色到对应 Redis 客户端的映射
|
||||
voice_to_redis = {
|
||||
"default": redis_tts_girl,
|
||||
"girl": redis_tts_girl,
|
||||
"woman": redis_tts_woman,
|
||||
"man": redis_tts_man,
|
||||
"leijun": redis_tts_leijun,
|
||||
"dufu": redis_tts_dufu,
|
||||
"hejiong": redis_tts_hejiong,
|
||||
"mahuateng": redis_tts_mahuateng,
|
||||
"lidan": redis_tts_lidan,
|
||||
"yuhua": redis_tts_yuhua,
|
||||
"liuzhenyun": redis_tts_liuzhenyun,
|
||||
"dabing": redis_tts_dabing,
|
||||
"luoxiang": redis_tts_luoxiang,
|
||||
"xuzhiyuan": redis_tts_xuzhiyuan
|
||||
}
|
||||
|
||||
i18n = I18nAuto()

# Reference audio/transcript paths per voice, from the environment.
# NOTE(review): the DABING/LUOXIANG/XUZHIYUAN *_REF_* variables are not in
# the committed .env sample — they will be None unless set in deployment.
GIRL_REF_AUDIO = os.getenv('GIRL_REF_AUDIO')
GIRL_REF_TEXT = os.getenv('GIRL_REF_TEXT')
WOMAN_REF_AUDIO = os.getenv('WOMAN_REF_AUDIO')
WOMAN_REF_TEXT = os.getenv('WOMAN_REF_TEXT')
MAN_REF_AUDIO = os.getenv('MAN_REF_AUDIO')
MAN_REF_TEXT = os.getenv('MAN_REF_TEXT')
LEIJUN_REF_AUDIO = os.getenv('LEIJUN_REF_AUDIO')
LEIJUN_REF_TEXT = os.getenv('LEIJUN_REF_TEXT')
DUFU_REF_AUDIO = os.getenv('DUFU_REF_AUDIO')
DUFU_REF_TEXT = os.getenv('DUFU_REF_TEXT')
HEJIONG_REF_AUDIO = os.getenv('HEJIONG_REF_AUDIO')
HEJIONG_REF_TEXT = os.getenv('HEJIONG_REF_TEXT')
MAHUATENG_REF_AUDIO = os.getenv('MAHUATENG_REF_AUDIO')
MAHUATENG_REF_TEXT = os.getenv('MAHUATENG_REF_TEXT')
LIDAN_REF_AUDIO = os.getenv('LIDAN_REF_AUDIO')
LIDAN_REF_TEXT = os.getenv('LIDAN_REF_TEXT')
YUHUA_REF_AUDIO = os.getenv('YUHUA_REF_AUDIO')
YUHUA_REF_TEXT = os.getenv('YUHUA_REF_TEXT')
LIUZHENYUN_REF_AUDIO = os.getenv('LIUZHENYUN_REF_AUDIO')
LIUZHENYUN_REF_TEXT = os.getenv('LIUZHENYUN_REF_TEXT')
DABING_REF_AUDIO = os.getenv('DABING_REF_AUDIO')
DABING_REF_TEXT = os.getenv('DABING_REF_TEXT')
LUOXIANG_REF_AUDIO = os.getenv('LUOXIANG_REF_AUDIO')
LUOXIANG_REF_TEXT = os.getenv('LUOXIANG_REF_TEXT')
XUZHIYUAN_REF_AUDIO = os.getenv('XUZHIYUAN_REF_AUDIO')
XUZHIYUAN_REF_TEXT = os.getenv('XUZHIYUAN_REF_TEXT')

# Voice id -> reference material used to prompt GPT-SoVITS; "default"
# aliases the girl voice's references.
VOICE_CONFIGS = {
    "girl": {
        "ref_audio": GIRL_REF_AUDIO,
        "ref_text": GIRL_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "woman": {
        "ref_audio": WOMAN_REF_AUDIO,
        "ref_text": WOMAN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "man": {
        "ref_audio": MAN_REF_AUDIO,
        "ref_text": MAN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "leijun": {
        "ref_audio": LEIJUN_REF_AUDIO,
        "ref_text": LEIJUN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "dufu": {
        "ref_audio": DUFU_REF_AUDIO,
        "ref_text": DUFU_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "hejiong": {
        "ref_audio": HEJIONG_REF_AUDIO,
        "ref_text": HEJIONG_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "mahuateng": {
        "ref_audio": MAHUATENG_REF_AUDIO,
        "ref_text": MAHUATENG_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "lidan": {
        "ref_audio": LIDAN_REF_AUDIO,
        "ref_text": LIDAN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "default": {
        "ref_audio": GIRL_REF_AUDIO,
        "ref_text": GIRL_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "yuhua": {
        "ref_audio": YUHUA_REF_AUDIO,
        "ref_text": YUHUA_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "liuzhenyun": {
        "ref_audio": LIUZHENYUN_REF_AUDIO,
        "ref_text": LIUZHENYUN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "dabing": {
        "ref_audio": DABING_REF_AUDIO,
        "ref_text": DABING_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "luoxiang": {
        "ref_audio": LUOXIANG_REF_AUDIO,
        "ref_text": LUOXIANG_REF_TEXT,
        "ref_language": REF_LANGUAGE
    },
    "xuzhiyuan": {
        "ref_audio": XUZHIYUAN_REF_AUDIO,
        "ref_text": XUZHIYUAN_REF_TEXT,
        "ref_language": REF_LANGUAGE
    }
}
|
||||
|
||||
def get_audio_hash(text):
    """Return the MD5 hex digest of *text* (UTF-8), used as an audio cache key."""
    digest = hashlib.md5(text.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
# Initialize the GPT and SoVITS model weights once at process startup,
# before the Kafka consumer loop begins pulling TTS tasks.
print("正在初始化模型...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
print("模型初始化成功。")
|
||||
|
||||
def read_ref_text(voice_type):
    """Load the reference transcript configured for *voice_type*.

    Returns the file contents, or an empty string when the file is
    missing or unreadable (a warning/error is printed in those cases).
    """
    ref_text_path = VOICE_CONFIGS[voice_type]["ref_text"]
    try:
        if not os.path.exists(ref_text_path):
            print(f"警告:{voice_type} 的参考文本文件 '{ref_text_path}' 不存在。")
            return ""
        with open(ref_text_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except IOError as e:
        print(f"错误:无法读取 {voice_type} 的参考文本文件 '{ref_text_path}'。{str(e)}")
        return ""
|
||||
|
||||
def synthesize(target_text, output_wav_path, voice):
    """Synthesize *target_text* to a WAV file using the given voice preset.

    :param target_text: text to convert to speech
    :param output_wav_path: destination path for the generated WAV file
    :param voice: key into VOICE_CONFIGS selecting the reference prompt
    :return: *output_wav_path* on success, or None when synthesis yields
             no audio segments
    """
    voice_config = VOICE_CONFIGS[voice]

    # Read the reference transcript that accompanies the reference audio.
    with open(voice_config["ref_text"], 'r', encoding='utf-8') as handle:
        prompt_text = handle.read()

    # Run TTS under the configured CUDA device; consume the generator
    # while the device context is still active.
    with torch.cuda.device(device):
        segments = list(get_tts_wav(
            ref_wav_path=voice_config["ref_audio"],
            prompt_text=prompt_text,
            prompt_language=i18n(voice_config["ref_language"]),
            text=target_text,
            text_language=i18n(TARGET_LANGUAGE),
            top_p=1,
            temperature=1
        ))

    if not segments:
        return None

    # Only the final (sampling_rate, audio) pair is written out.
    sampling_rate, audio_data = segments[-1]
    sf.write(output_wav_path, audio_data, sampling_rate)
    return output_wav_path
|
||||
|
||||
def kafka_consumer():
    """Consume TTS tasks from Kafka and synthesize audio for each message.

    Each message is JSON carrying 'task_id', 'text', 'text_hash', and an
    optional 'voice'. Synthesized audio paths are cached per-voice in Redis
    under "tts:<text_hash>"; per-task status ("processing"/"completed"/
    "failed") and task info are written via redis_task_client. Runs forever.
    """
    consumer = KafkaConsumer(
        KAFKA_TTS_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    print(f"TTS消费者已启动")
    for message in consumer:
        task_id = None
        error_occurred = False  # reset per message (moved to the start of the loop)
        try:
            task_id = message.value['task_id']
            target_text = message.value['text']
            text_hash = message.value['text_hash']
            voice = message.value.get('voice', 'default')

            # "default" is an alias for the "girl" preset; unknown voice
            # names also fall back to "girl".
            if voice == 'default':
                voice = 'girl'
            if voice not in VOICE_CONFIGS:
                print(f"警告:无效的音色类型 '{voice}'。使用默认音色。")
                voice = "girl"

            # Mark the task as in progress.
            redis_task_client.set(f"task_status:tts:{task_id}", "processing")

            # Each voice has a dedicated Redis client/DB for cached audio.
            redis_tts = voice_to_redis[voice]

            # Reuse a previously synthesized file for identical text if it
            # still exists on disk; otherwise (re)generate it.
            existing_audio_info = redis_tts.get(f"tts:{text_hash}")
            if existing_audio_info:
                existing_audio_path = json.loads(existing_audio_info)['path']
                if os.path.exists(existing_audio_path):
                    # Cached file still present — use it directly.
                    output_path = existing_audio_path
                else:
                    # Cached path is stale — synthesize again.
                    output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}_{voice}.wav")
                    output_path = synthesize(target_text, output_wav_path, voice)
            else:
                # No cache entry — create a new audio file.
                output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}_{voice}.wav")
                output_path = synthesize(target_text, output_wav_path, voice)

            if output_path:
                # Record the result path in the voice-specific Redis DB.
                redis_tts.set(f"tts:{text_hash}", json.dumps({"path": output_path}))
                print(f"音频合成成功: {output_path}")

                # Mark the task as completed.
                redis_task_client.set(f"task_status:tts:{task_id}", "completed")

                # Store task metadata for later lookup.
                redis_task_client.set(f"task_info:tts:{task_id}", json.dumps({
                    "text_hash": text_hash,
                    "voice": voice
                }))
            else:
                print("音频合成失败")
                error_occurred = True

        except KeyError as e:
            print(f"错误:消息中缺少必要的键: {e}")
            error_occurred = True
        except Exception as e:
            print(f"处理消息时出错: {str(e)}")
            print(traceback.format_exc())
            error_occurred = True
        finally:
            if error_occurred:
                print("处理消息时发生错误")
                if task_id:
                    redis_task_client.set(f"task_status:tts:{task_id}", "failed")
            else:
                print("消息处理完成")
|
||||
if __name__ == "__main__":
    # Pin CUDA to the configured device, then run the blocking consumer loop.
    torch.cuda.set_device(device)
    kafka_consumer()
|
||||
@@ -0,0 +1,68 @@
|
||||
import os
|
||||
import whisper
|
||||
import argparse
|
||||
|
||||
def transcribe_audio(model, audio_path):
    """Transcribe one audio file with a loaded Whisper model.

    :param model: loaded Whisper model
    :param audio_path: path of the audio file to transcribe
    :return: the transcribed text, or None when transcription fails
    """
    try:
        return model.transcribe(audio_path)["text"]
    except Exception as e:
        print(f"转录失败 {audio_path}: {str(e)}")
        return None
|
||||
|
||||
def process_directory(directory, model):
    """Transcribe every WAV file in *directory*, writing a .txt beside each.

    :param directory: path of a directory containing WAV files
    :param model: loaded Whisper model
    """
    for entry in os.listdir(directory):
        if not entry.lower().endswith('.wav'):
            continue
        input_file = os.path.join(directory, entry)
        output_file = os.path.splitext(input_file)[0] + ".txt"

        print(f"正在处理: {input_file}")
        text = transcribe_audio(model, input_file)

        if text:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(text)
            print(f"转录完成: {output_file}")
        else:
            print(f"转录失败: {input_file}")
|
||||
|
||||
def main():
    """CLI entry point: transcribe a single WAV file or a directory of WAVs.

    Loads the requested Whisper model once, then dispatches on whether the
    input path is a file or a directory. Transcripts are written next to
    each input with a .txt extension.
    """
    parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
    parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"], help="Whisper模型大小")
    args = parser.parse_args()

    print(f"正在加载Whisper模型 ({args.model})...")
    model = whisper.load_model(args.model)
    print("模型加载完成")

    if os.path.isfile(args.input_path):
        if not args.input_path.lower().endswith('.wav'):
            print("错误: 输入文件不是WAV格式")
            return
        output_file = os.path.splitext(args.input_path)[0] + ".txt"
        transcription = transcribe_audio(model, args.input_path)
        if transcription:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(transcription)
            print(f"转录完成: {output_file}")
        else:
            # Fix: the single-file path used to fail silently; report the
            # failure the same way process_directory does.
            print(f"转录失败: {args.input_path}")
    elif os.path.isdir(args.input_path):
        process_directory(args.input_path, model)
    else:
        print("错误: 输入路径既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
||||
@@ -0,0 +1 @@
|
||||
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
Reference in New Issue
Block a user