update 20250403
This commit is contained in:
+30
-29
@@ -1,28 +1,27 @@
|
||||
# Kafka 配置
|
||||
KAFKA_BROKER=222.186.10.253:9092
|
||||
KAFKA_BROKER=222.186.20.67:9092
|
||||
KAFKA_ASR_TOPIC=asr
|
||||
KAFKA_CHAT_TOPIC=chat
|
||||
KAFKA_TTS_TOPIC=tts
|
||||
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST=150.158.144.159
|
||||
REDIS_PORT=13003
|
||||
REDIS_ASR_DB=12
|
||||
REDIS_CHAT_DB=13
|
||||
REDIS_TTS_DB=14
|
||||
REDIS_HOST=222.186.20.67
|
||||
REDIS_PORT=6379
|
||||
REDIS_ASR_DB=43
|
||||
REDIS_CHAT_DB=44
|
||||
REDIS_TTS_DB=45
|
||||
REDIS_PASSWORD=Obscura@2024
|
||||
REDIS_API_DB=2
|
||||
REDIS_API_USAGE_DB=3
|
||||
REDIS_TASK_DB=11
|
||||
REDIS_SESSION_DB=63
|
||||
REDIS_API_DB=31
|
||||
REDIS_API_USAGE_DB=32
|
||||
REDIS_TASK_DB=46
|
||||
REDIS_SESSION_DB=47
|
||||
|
||||
REDIS_SESSION_DB_ZH=48
|
||||
REDIS_SESSION_DB_EN=49
|
||||
REDIS_SESSION_DB_KO=50
|
||||
|
||||
REDIS_SESSION_DB_ZH=63
|
||||
REDIS_SESSION_DB_EN=62
|
||||
REDIS_SESSION_DB_KO=61
|
||||
|
||||
# CORS 配置
|
||||
# ALLOWED_ORIGINS=https://beta.obscura.work
|
||||
|
||||
|
||||
# GPT-SoVITS 配置
|
||||
@@ -76,19 +75,21 @@ XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
|
||||
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
|
||||
|
||||
|
||||
REDIS_GIRL_DB = 15
|
||||
REDIS_WOMAN_DB = 16
|
||||
REDIS_MAN_DB = 17
|
||||
REDIS_LEIJUN_DB = 18
|
||||
REDIS_DUFU_DB = 19
|
||||
REDIS_HEJIONG_DB = 20
|
||||
REDIS_MAHUATENG_DB = 21
|
||||
REDIS_LIDAN_DB = 22
|
||||
REDIS_DABING_DB = 23
|
||||
REDIS_LUOXIANG_DB = 24
|
||||
REDIS_XUZHIYUAN_DB = 25
|
||||
REDIS_YUHUA_DB = 26
|
||||
REDIS_LIUZHENYUN_DB = 27
|
||||
|
||||
REDIS_GIRL_DB = 51
|
||||
REDIS_WOMAN_DB = 52
|
||||
REDIS_MAN_DB = 53
|
||||
REDIS_LEIJUN_DB = 54
|
||||
REDIS_DUFU_DB = 55
|
||||
REDIS_HEJIONG_DB = 56
|
||||
REDIS_MAHUATENG_DB = 57
|
||||
REDIS_LIDAN_DB = 58
|
||||
REDIS_DABING_DB = 59
|
||||
REDIS_LUOXIANG_DB = 60
|
||||
REDIS_XUZHIYUAN_DB = 61
|
||||
REDIS_YUHUA_DB = 62
|
||||
REDIS_LIUZHENYUN_DB = 63
|
||||
|
||||
|
||||
# Ollama API配置 - 多个地址用逗号分隔
|
||||
OLLAMA_URLS=http://222.186.20.67:11435,http://222.186.20.67:11436,http://222.186.20.67:11437,http://222.186.20.67:11438,http://222.186.20.67:11439,http://222.186.20.67:11440,http://222.186.20.67:11441
|
||||
OLLAMA_TIMEOUT=10 # API请求超时时间(秒)
|
||||
Regular → Executable
+4
-1
@@ -6,6 +6,9 @@ from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import asyncio
|
||||
|
||||
# 在导入其他库之前设置
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
@@ -16,7 +19,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
model = whisper.load_model("small")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import whisper
|
||||
import os
|
||||
import json
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
from kafka import KafkaConsumer
|
||||
import threading
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_asr_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_ASR_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
def process_audio(file_path: str, client_id: str, cache_key: str):
|
||||
try:
|
||||
# 设置任务状态为 "processing"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "processing")
|
||||
|
||||
result = model.transcribe(file_path)
|
||||
transcription = result['text']
|
||||
|
||||
print(f"处理了文件: {file_path}")
|
||||
print(f"转录结果: {transcription}")
|
||||
|
||||
# 将结果存入Redis缓存
|
||||
redis_asr_client.setex(cache_key, 3600, transcription) # 缓存1小时
|
||||
|
||||
# 发布结果到Redis频道
|
||||
result_data = {
|
||||
'client_id': client_id,
|
||||
'transcription': transcription
|
||||
}
|
||||
redis_asr_client.publish('asr_results', json.dumps(result_data))
|
||||
|
||||
# 设置任务状态为 "completed"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "completed")
|
||||
|
||||
# 清理临时文件
|
||||
os.remove(file_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理音频文件时发生错误: {str(e)}")
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"task_status:{cache_key}", "error")
|
||||
|
||||
def kafka_consumer():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8')),
|
||||
group_id='asr_group',
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True
|
||||
)
|
||||
|
||||
print(f"ASR消费者已启动")
|
||||
|
||||
for message in consumer:
|
||||
try:
|
||||
task = message.value
|
||||
file_path = task.get('file_path')
|
||||
task_id = task.get('task_id')
|
||||
status = task.get('status')
|
||||
|
||||
if not file_path or not task_id or status != 'queued':
|
||||
print(f"收到无效任务: {task}")
|
||||
continue
|
||||
|
||||
cache_key = f"asr:{task_id}"
|
||||
client_id = task_id # 使用task_id作为client_id
|
||||
|
||||
print(f"开始处理任务: {cache_key}")
|
||||
process_audio(file_path, client_id, cache_key)
|
||||
print(f"完成处理任务: {cache_key}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理消息时发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("启动Kafka消费者处理ASR请求...")
|
||||
kafka_consumer()
|
||||
@@ -1,117 +0,0 @@
|
||||
from kafka import KafkaConsumer
|
||||
import json
|
||||
import asyncio
|
||||
import redis
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import requests
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 3 # 消费者数量
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_CHAT_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
KAFKA_CHAT_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='latest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
model = chat_request.get('model', 'qwen2.5:3b')
|
||||
|
||||
# 设置任务状态为 "processing"
|
||||
redis_task_client.set(f"task_status:{session_id}", "processing")
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
|
||||
for past_query, past_response in history:
|
||||
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
|
||||
full_prompt += f"用户: {query}\n助手:"
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": full_prompt,
|
||||
"stream": True,
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
json_data = json.loads(line)
|
||||
if 'response' in json_data:
|
||||
text_output += json_data['response']
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
# 设置任务状态为 "completed"
|
||||
redis_task_client.set(f"task_status:{session_id}", "completed")
|
||||
|
||||
print(f"处理完成 session {session_id}: {text_output}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理 session {chat_request['session_id']} 时出错: {str(e)}")
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"task_status:{chat_request['session_id']}", "error")
|
||||
|
||||
def kafka_consumer_thread(consumer_id):
|
||||
consumer = create_kafka_consumer()
|
||||
print(f"消费者 {consumer_id} 已启动")
|
||||
for message in consumer:
|
||||
chat_request = message.value
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
def main():
|
||||
print("启动Kafka消费者处理聊天请求...")
|
||||
with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
|
||||
for i in range(KAFKA_CONSUMER_NUM):
|
||||
executor.submit(kafka_consumer_thread, i)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,182 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load MiniCPM3-4B model
|
||||
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'mini3_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "minicpm3-4b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/mini3", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
session_id = request.session_id
|
||||
query = request.query
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
|
||||
for past_query, past_response in history:
|
||||
messages.append({"role": "user", "content": past_query})
|
||||
messages.append({"role": "assistant", "content": past_response})
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
try:
|
||||
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
|
||||
|
||||
# 创建注意力掩码
|
||||
attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()
|
||||
|
||||
# 将输入移动到正确的设备(CPU或GPU)
|
||||
model_inputs = model_inputs.to(device)
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
model_outputs = model.generate(
|
||||
model_inputs,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=1024,
|
||||
top_p=0.7,
|
||||
temperature=0.7,
|
||||
pad_token_id=tokenizer.eos_token_id, # 将pad_token_id设置为eos_token_id
|
||||
do_sample=True
|
||||
)
|
||||
|
||||
output_token_ids = model_outputs[0][len(model_inputs[0]):]
|
||||
text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
return ChatResponse(response=text_output, history=history)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
|
||||
async def start_chat():
|
||||
session_id = str(uuid.uuid4())
|
||||
return {"session_id": session_id}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -1,63 +0,0 @@
|
||||
import os
|
||||
from moviepy.editor import VideoFileClip
|
||||
|
||||
def mp4_to_wav(input_file, output_file):
|
||||
"""
|
||||
将MP4文件转换为WAV格式
|
||||
|
||||
:param input_file: 输入的MP4文件路径
|
||||
:param output_file: 输出的WAV文件路径
|
||||
"""
|
||||
try:
|
||||
# 加载视频文件
|
||||
video = VideoFileClip(input_file)
|
||||
|
||||
# 提取音频
|
||||
audio = video.audio
|
||||
|
||||
# 将音频写入WAV文件
|
||||
audio.write_audiofile(output_file)
|
||||
|
||||
# 关闭视频和音频对象
|
||||
audio.close()
|
||||
video.close()
|
||||
|
||||
print(f"转换成功: {input_file} -> {output_file}")
|
||||
except Exception as e:
|
||||
print(f"转换失败: {input_file} - {str(e)}")
|
||||
|
||||
def process_directory(directory):
|
||||
"""
|
||||
处理目录中的所有MP4文件
|
||||
|
||||
:param directory: 包含MP4文件的目录路径
|
||||
"""
|
||||
for filename in os.listdir(directory):
|
||||
if filename.lower().endswith('.mp4'):
|
||||
input_file = os.path.join(directory, filename)
|
||||
output_file = os.path.splitext(input_file)[0] + ".wav"
|
||||
mp4_to_wav(input_file, output_file)
|
||||
|
||||
def main():
|
||||
# 获取输入路径
|
||||
input_path = input("请输入MP4文件或包含MP4文件的目录路径: ").strip()
|
||||
|
||||
# 检查输入路径是否存在
|
||||
if not os.path.exists(input_path):
|
||||
print("错误: 输入路径不存在")
|
||||
return
|
||||
|
||||
# 判断输入路径是文件还是目录
|
||||
if os.path.isfile(input_path):
|
||||
if not input_path.lower().endswith('.mp4'):
|
||||
print("错误: 输入文件不是MP4格式")
|
||||
return
|
||||
output_file = os.path.splitext(input_path)[0] + ".wav"
|
||||
mp4_to_wav(input_path, output_file)
|
||||
elif os.path.isdir(input_path):
|
||||
process_directory(input_path)
|
||||
else:
|
||||
print("错误: 输入路径既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,170 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import requests
|
||||
import json
|
||||
from typing import List, Tuple
|
||||
from kafka import KafkaConsumer, TopicPartition
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import asyncio
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
torch.cuda.set_device(device)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Kafka 设置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
KAFKA_CONSUMER_GROUP = 'chat_group'
|
||||
KAFKA_CONSUMER_NUM = 1
|
||||
|
||||
# Redis 设置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 使用密码进行认证
|
||||
)
|
||||
# 创建Kafka消费者
|
||||
def create_kafka_consumer():
|
||||
return KafkaConsumer(
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
auto_offset_reset='earliest',
|
||||
enable_auto_commit=True,
|
||||
group_id=KAFKA_CONSUMER_GROUP,
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
# Kafka消费者函数
|
||||
def kafka_consumer(consumer, consumer_id):
|
||||
# 获取消费者分配的分区
|
||||
consumer.subscribe([KAFKA_TOPIC])
|
||||
partitions = consumer.assignment()
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
|
||||
|
||||
for message in consumer:
|
||||
partition = message.partition
|
||||
offset = message.offset
|
||||
chat_request = message.value # 直接使用 message.value,它已经是一个字典
|
||||
session_id = chat_request['session_id']
|
||||
query = chat_request['query']
|
||||
|
||||
logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
|
||||
|
||||
asyncio.run(process_chat_request(chat_request))
|
||||
|
||||
# 启动Kafka消费者线程
|
||||
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
|
||||
consumers = []
|
||||
for i in range(num_consumers):
|
||||
consumer = create_kafka_consumer()
|
||||
consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
|
||||
consumer_thread.start()
|
||||
consumers.append((consumer, consumer_thread))
|
||||
return consumers
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
query: str
|
||||
model: str = "qwen2.5:3b"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
history: List[Tuple[str, str]]
|
||||
|
||||
# 处理聊天请求的异步函数
|
||||
async def process_chat_request(chat_request):
|
||||
try:
|
||||
response = await chat(ChatRequest(**chat_request))
|
||||
print(f"Processed message for session {chat_request['session_id']}: {response}")
|
||||
except Exception as e:
|
||||
print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
session_id = request.session_id
|
||||
query = request.query
|
||||
model = request.model
|
||||
|
||||
# 从Redis获取历史记录
|
||||
history = json.loads(redis_client.get(session_id) or '[]')
|
||||
|
||||
# 构建包含历史对话的完整提示
|
||||
full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
|
||||
for past_query, past_response in history:
|
||||
full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
|
||||
full_prompt += f"用户: {query}"
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": full_prompt,
|
||||
"stream": True,
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post("http://127.0.0.1:11434/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
json_data = json.loads(line)
|
||||
if 'response' in json_data:
|
||||
text_output += json_data['response']
|
||||
|
||||
# 更新历史记录
|
||||
history.append((query, text_output))
|
||||
redis_client.set(session_id, json.dumps(history))
|
||||
|
||||
return ChatResponse(response=text_output, history=history)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/start_chat")
|
||||
async def start_chat():
|
||||
session_id = str(uuid.uuid4())
|
||||
return {"session_id": session_id}
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 启动Kafka消费者线程
|
||||
start_kafka_consumers()
|
||||
|
||||
# 启动FastAPI服务器
|
||||
uvicorn.run(app, host="0.0.0.0", port=6001)
|
||||
@@ -1,136 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import httpx
|
||||
import json
|
||||
import redis
|
||||
from typing import List, Dict, Optional
|
||||
import logging
|
||||
import ollama
|
||||
import uuid
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Redis连接
|
||||
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
model: Optional[str] = "qwen2.5:3b"
|
||||
prompt: str
|
||||
|
||||
class RawGenerateRequest(BaseModel):
|
||||
model: Optional[str] = "qwen2.5:3b"
|
||||
prompt: str
|
||||
system_prompt: Optional[str] = None
|
||||
stream: Optional[bool] = False
|
||||
raw: Optional[bool] = False
|
||||
format: Optional[str] = None
|
||||
options: Optional[Dict] = None
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
response: dict
|
||||
request_id: str
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
logger.info(f"收到请求: {request}")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
|
||||
full_response = response['message']['content']
|
||||
|
||||
request_data = {
|
||||
"model": request.model,
|
||||
"prompt": request.prompt,
|
||||
"response": full_response
|
||||
}
|
||||
|
||||
redis_client.set(f"request:{request_id}", json.dumps(request_data))
|
||||
|
||||
response_data = {
|
||||
"response": full_response,
|
||||
"model": request.model
|
||||
}
|
||||
|
||||
return GenerateResponse(response=response_data, request_id=request_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate_without_history(request: RawGenerateRequest):
|
||||
"""
|
||||
处理无历史记录的生成请求。
|
||||
|
||||
参数:
|
||||
- request: RawGenerateRequest对象,包含生成请求的所有参数。
|
||||
|
||||
返回:
|
||||
- 包含生成结果的字典。
|
||||
"""
|
||||
try:
|
||||
response = ollama.generate(
|
||||
model=request.model,
|
||||
prompt=request.prompt,
|
||||
system=request.system_prompt,
|
||||
format=request.format,
|
||||
options=request.options,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
response_data = {
|
||||
"model": request.model,
|
||||
"response": response['response'],
|
||||
"done": True,
|
||||
"context": response.get('context'),
|
||||
"total_duration": response.get('total_duration'),
|
||||
"load_duration": response.get('load_duration'),
|
||||
"prompt_eval_count": response.get('prompt_eval_count'),
|
||||
"prompt_eval_duration": response.get('prompt_eval_duration'),
|
||||
"eval_count": response.get('eval_count'),
|
||||
"eval_duration": response.get('eval_duration')
|
||||
}
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
redis_client.set(f"request:{request_id}", json.dumps(response_data))
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发生未预期的错误: {e}")
|
||||
logger.exception("详细错误信息:")
|
||||
raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
|
||||
|
||||
@app.get("/request/{request_id}", response_model=Dict)
|
||||
async def get_request(request_id: str):
|
||||
request_data = redis_client.get(f"request:{request_id}")
|
||||
if request_data:
|
||||
return json.loads(request_data)
|
||||
raise HTTPException(status_code=404, detail="请求未找到")
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
return ollama.list()
|
||||
|
||||
@app.get("/models/{model_name}")
|
||||
async def show_model(model_name: str):
|
||||
return ollama.show(model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=7000)
|
||||
@@ -1,319 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
import asyncio
|
||||
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
app = FastAPI()
|
||||
v1_chat_app = FastAPI()
|
||||
app.mount("/v1_chat", v1_chat_app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
|
||||
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
|
||||
|
||||
# 初始化 Kafka Producer
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 初始化 Redis
|
||||
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
|
||||
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
|
||||
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
|
||||
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
|
||||
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
|
||||
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
|
||||
|
||||
# 定义API密钥头部
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# 验证API密钥的函数
|
||||
async def get_api_key(api_key: str = Security(api_key_header)):
|
||||
if api_key and api_key.startswith("Bearer "):
|
||||
key = api_key.split(" ")[1]
|
||||
if key.startswith("obs-"):
|
||||
return key
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="无效的API密钥",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def verify_api_key(api_key: str = Depends(get_api_key)):
|
||||
redis_key = f"api_key:{api_key}"
|
||||
|
||||
api_key_info = redis_api_client.hgetall(redis_key)
|
||||
|
||||
if not api_key_info:
|
||||
raise HTTPException(status_code=401, detail="无效的API密钥")
|
||||
|
||||
api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
|
||||
|
||||
if api_key_info.get('is_active') != '1':
|
||||
raise HTTPException(status_code=401, detail="API密钥已停用")
|
||||
|
||||
expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(status_code=401, detail="API密钥已过期")
|
||||
|
||||
usage_info = redis_api_usage_client.hgetall(redis_key)
|
||||
usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
|
||||
|
||||
return {
|
||||
"api_key": api_key,
|
||||
**api_key_info,
|
||||
**usage_info
|
||||
}
|
||||
|
||||
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Add *new_tokens_used* to the key's global and per-model usage counters.

    All four writes (two counters, two last-used timestamps) go through one
    Redis pipeline so they are sent as a single round trip.
    """
    redis_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()

    pipe = redis_api_usage_client.pipeline()
    for tokens_field, timestamp_field in (
        ("tokens_used", "last_used_at"),
        (f"{model_name}_tokens_used", f"{model_name}_last_used_at"),
    ):
        pipe.hincrby(redis_key, tokens_field, new_tokens_used)
        pipe.hset(redis_key, timestamp_field, now_iso)
    pipe.execute()
|
||||
|
||||
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge *tokens_required* against the key's quota, enqueue *task_data*
    on *kafka_topic*, and return an updated usage summary.

    NOTE(review): the quota check below and the later increment are not
    atomic — two concurrent requests can both pass the check and overdraw
    the quota. A Lua script or WATCH/MULTI transaction would close the race;
    confirm whether this matters at current traffic levels.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    # Current quota and consumption; absent fields count as 0.
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Charge tokens up-front, before the task actually runs.
    await update_token_usage(api_key, tokens_required, model_name)

    # Publish the task to Kafka.
    producer.send(kafka_topic, task_data)

    # Re-read usage after the charge so the response reflects the new totals.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
|
||||
|
||||
class TTSRequest(BaseModel):
    """Request body for POST /tts."""
    # Text to synthesize.
    text: str
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    # Conversation identifier used to load/store history.
    session_id: str
    # The user's message.
    query: str
    # Ollama model name; defaults to qwen2.5:3b.
    model: str = "qwen2.5:3b"
|
||||
|
||||
|
||||
# 添加WebSocket连接管理
|
||||
class ConnectionManager:
    """Registry of live WebSocket connections keyed by client id."""

    def __init__(self):
        # client_id -> WebSocket
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register the socket under *client_id*."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget the client; a no-op if it is not registered."""
        self.active_connections.pop(client_id, None)

    async def send_message(self, message: str, client_id: str):
        """Push *message* to the client if still connected; drop it otherwise."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.send_text(message)


manager = ConnectionManager()
|
||||
|
||||
|
||||
@v1_chat_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Hold a WebSocket open for *client_id* until the peer disconnects.

    Incoming messages are read and discarded — the socket exists so the
    server can push task results via `manager.send_message`.
    """
    await manager.connect(websocket, client_id)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(client_id)
|
||||
|
||||
# 修改TTS请求处理函数
|
||||
@v1_chat_app.post("/tts")
|
||||
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"text": request.text,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "tts", 100, task_data, KAFKA_TTS_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
# 修改ASR请求处理函数
|
||||
@v1_chat_app.post("/asr")
|
||||
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
UPLOAD_DIR = "/obscura/task/audio_upload"
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
|
||||
|
||||
with open(file_path, "wb") as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
|
||||
task_data = {
|
||||
'file_path': file_path,
|
||||
'task_id': task_id,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "asr", 100, task_data, KAFKA_ASR_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
# 修改聊天请求处理函数
|
||||
@v1_chat_app.post("/chat")
|
||||
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
|
||||
task_id = str(uuid.uuid4())
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"session_id": request.session_id,
|
||||
"query": request.query,
|
||||
"model": request.model,
|
||||
"status": "queued",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
redis_task_client.set(f"task_status:{task_id}", "queued")
|
||||
|
||||
result = await process_request(api_key_info, "chat", 100, task_data, KAFKA_CHAT_TOPIC)
|
||||
result["task_id"] = task_id
|
||||
|
||||
# 将任务ID存储到Redis,以便后续WebSocket通信使用
|
||||
redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
|
||||
|
||||
return result
|
||||
|
||||
@v1_chat_app.get("/chat_result/{task_id}")
|
||||
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis聊天结果数据库获取聊天结果
|
||||
chat_result = redis_chat_client.get(task_id)
|
||||
if chat_result:
|
||||
result = json.loads(chat_result)
|
||||
return {
|
||||
"status": "completed",
|
||||
"history": result # 直接返回整个历史记录
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/tts_result/{task_id}")
|
||||
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis TTS结果数据库获取音频文件路径
|
||||
audio_info = redis_tts_client.get(task_id)
|
||||
if audio_info:
|
||||
audio_path = json.loads(audio_info)['path']
|
||||
return {
|
||||
"status": "completed",
|
||||
"audio_path": audio_path
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
@v1_chat_app.get("/asr_result/{task_id}")
|
||||
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
|
||||
# 从Redis任务数据库获取任务状态
|
||||
task_status = redis_task_client.get(f"task_status:{task_id}")
|
||||
if task_status:
|
||||
status = task_status.decode('utf-8')
|
||||
if status == "completed":
|
||||
# 从Redis ASR结果数据库获取转录结果
|
||||
transcription = redis_asr_client.get(task_id)
|
||||
if transcription:
|
||||
return {
|
||||
"status": "completed",
|
||||
"transcription": transcription.decode('utf-8')
|
||||
}
|
||||
return {"status": status}
|
||||
|
||||
return {"status": "not_found"}
|
||||
|
||||
if __name__ == "__main__":
    # Dev entry point; run behind a proper ASGI server setup in production.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
@@ -1,180 +0,0 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
    """Hex MD5 of *text*; keys the Redis cache and names output WAV files."""
    encoded = text.encode()
    return hashlib.md5(encoded).hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
    """Run GPT-SoVITS synthesis for *target_text*; write a WAV into *output_path*.

    Returns the written file's path, or None if synthesis produced nothing.
    Relies on module globals: the loaded models, REF_AUDIO_PATH / ref_text /
    REF_LANGUAGE / TARGET_LANGUAGE, and the CUDA `device`.
    """
    # Synthesize audio on the configured CUDA device.
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)

        # get_tts_wav yields chunks; materialize while the device is active.
        result_list = list(synthesis_result)

    if result_list:
        # Only the final (sampling_rate, samples) pair is written out.
        last_sampling_rate, last_audio_data = result_list[-1]
        audio_hash = get_audio_hash(target_text)
        output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        return output_wav_path
    else:
        return None
|
||||
|
||||
@app.post("/tts")
|
||||
async def synthesize_audio(request: TTSRequest):
|
||||
try:
|
||||
print(f"Received TTS request: {request.dict()}")
|
||||
target_text = request.text
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
|
||||
# Check Redis cache
|
||||
cached_audio = redis_client.get(audio_hash)
|
||||
if cached_audio:
|
||||
audio_info = json.loads(cached_audio)
|
||||
return FileResponse(audio_info['path'], media_type="audio/wav")
|
||||
|
||||
# Check file system
|
||||
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
|
||||
if os.path.exists(file_path):
|
||||
# Cache the file path in Redis
|
||||
redis_client.set(audio_hash, json.dumps({"path": file_path}))
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
|
||||
# Send message to Kafka
|
||||
kafka_producer.send(KAFKA_TOPIC, json.dumps({
|
||||
'text': target_text,
|
||||
'audio_hash': audio_hash
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
|
||||
for _ in range(60): # Wait for up to 30 seconds
|
||||
if os.path.exists(file_path):
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
time.sleep(1)
|
||||
|
||||
# If audio is not generated within the timeout
|
||||
raise HTTPException(status_code=504, detail="Audio generation timed out")
|
||||
except Exception as e:
|
||||
print(f"Error processing TTS request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
    """Background worker: consume TTS jobs from Kafka and synthesize them.

    Each message carries {'text', 'audio_hash'}; on success the WAV path is
    cached in Redis under the audio hash so the HTTP endpoint can serve it.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        # group_id=KAFKA_GROUP_ID,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )

    for message in consumer:
        target_text = message.value['text']
        audio_hash = message.value['audio_hash']

        output_path = synthesize(target_text, OUTPUT_PATH)

        if output_path:
            # Cache hash -> path so /tts finds it without re-synthesizing.
            redis_client.set(audio_hash, json.dumps({"path": output_path}))
            print(f"Audio synthesized successfully: {output_path}")
        else:
            print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6002)
|
||||
@@ -1,180 +0,0 @@
|
||||
import os
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import uvicorn
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
import threading
|
||||
import time
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# FastAPI configuration
|
||||
app = FastAPI()
|
||||
i18n = I18nAuto()
|
||||
|
||||
# CORS configuration
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
# Redis configuration
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# Kafka configuration
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
# KAFKA_GROUP_ID = 'tts_group'
|
||||
KAFKA_CONSUMER_THREADS = 1
|
||||
|
||||
# TTS configuration
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# Initialize FastAPI CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize Redis client
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# Initialize Kafka producer
|
||||
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str = Field(..., alias="text")
|
||||
|
||||
def get_audio_hash(text):
|
||||
return hashlib.md5(text.encode()).hexdigest()
|
||||
|
||||
# Initialize models at startup
|
||||
print("Initializing models...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# Read reference text
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("Models initialized successfully.")
|
||||
|
||||
def synthesize(target_text, output_path):
|
||||
# Synthesize audio
|
||||
with torch.cuda.device(device):
|
||||
synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(REF_LANGUAGE),
|
||||
text=target_text,
|
||||
text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
return output_wav_path
|
||||
else:
|
||||
return None
|
||||
|
||||
@app.post("/tts_ko")
|
||||
async def synthesize_audio(request: TTSRequest):
|
||||
try:
|
||||
print(f"Received TTS request: {request.dict()}")
|
||||
target_text = request.text
|
||||
audio_hash = get_audio_hash(target_text)
|
||||
|
||||
# Check Redis cache
|
||||
cached_audio = redis_client.get(audio_hash)
|
||||
if cached_audio:
|
||||
audio_info = json.loads(cached_audio)
|
||||
return FileResponse(audio_info['path'], media_type="audio/wav")
|
||||
|
||||
# Check file system
|
||||
file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
|
||||
if os.path.exists(file_path):
|
||||
# Cache the file path in Redis
|
||||
redis_client.set(audio_hash, json.dumps({"path": file_path}))
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
|
||||
# Send message to Kafka
|
||||
kafka_producer.send(KAFKA_TOPIC, json.dumps({
|
||||
'text': target_text,
|
||||
'audio_hash': audio_hash
|
||||
}).encode('utf-8'))
|
||||
|
||||
# Wait for the audio to be generated (you might want to implement a more sophisticated waiting mechanism)
|
||||
for _ in range(60): # Wait for up to 30 seconds
|
||||
if os.path.exists(file_path):
|
||||
return FileResponse(file_path, media_type="audio/wav")
|
||||
time.sleep(1)
|
||||
|
||||
# If audio is not generated within the timeout
|
||||
raise HTTPException(status_code=504, detail="Audio generation timed out")
|
||||
except Exception as e:
|
||||
print(f"Error processing TTS request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "TTS API is running"}
|
||||
|
||||
def kafka_consumer_thread():
|
||||
consumer = KafkaConsumer(
|
||||
KAFKA_TOPIC,
|
||||
bootstrap_servers=KAFKA_BROKER,
|
||||
# group_id=KAFKA_GROUP_ID,
|
||||
auto_offset_reset='latest',
|
||||
value_deserializer=lambda m: json.loads(m.decode('utf-8'))
|
||||
)
|
||||
|
||||
for message in consumer:
|
||||
target_text = message.value['text']
|
||||
audio_hash = message.value['audio_hash']
|
||||
|
||||
output_path = synthesize(target_text, OUTPUT_PATH)
|
||||
|
||||
if output_path:
|
||||
redis_client.set(audio_hash, json.dumps({"path": output_path}))
|
||||
print(f"Audio synthesized successfully: {output_path}")
|
||||
else:
|
||||
print("Failed to synthesize audio")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Start Kafka consumer threads
|
||||
torch.cuda.set_device(device)
|
||||
for _ in range(KAFKA_CONSUMER_THREADS):
|
||||
consumer_thread = threading.Thread(target=kafka_consumer_thread)
|
||||
consumer_thread.start()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=6003)
|
||||
@@ -1,176 +0,0 @@
|
||||
# 导入所需的库
|
||||
import os
|
||||
import soundfile as sf
|
||||
import redis
|
||||
import hashlib
|
||||
import json
|
||||
from kafka import KafkaConsumer
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
|
||||
"""
|
||||
整体设计说明:
|
||||
这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务,
|
||||
使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。
|
||||
主要功能包括:
|
||||
1. 初始化配置和模型
|
||||
2. 提供语音合成功能
|
||||
3. 监听Kafka消息并处理TTS任务
|
||||
4. 将合成结果存储到Redis并更新任务状态
|
||||
"""
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 设置GPU设备(如果可用)
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 从环境变量中读取Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 从环境变量中读取Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
|
||||
|
||||
# 从环境变量中读取TTS相关配置
|
||||
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
|
||||
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
|
||||
REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH')
|
||||
REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH')
|
||||
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
|
||||
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
|
||||
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
|
||||
|
||||
# 初始化Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TTS_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
redis_task_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_TASK_DB,
|
||||
password=REDIS_PASSWORD
|
||||
)
|
||||
|
||||
# 初始化国际化工具
|
||||
i18n = I18nAuto()
|
||||
|
||||
def get_audio_hash(text):
    """Compute the MD5 hex digest of *text*.

    Used as part of the audio file name so identical texts map to the same
    cached WAV.

    Args:
        text: the text to hash.

    Returns:
        The MD5 hex digest string.
    """
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
|
||||
|
||||
# 初始化模型
|
||||
print("正在初始化模型...")
|
||||
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
|
||||
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
|
||||
|
||||
# 读取参考文本
|
||||
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
|
||||
ref_text = file.read()
|
||||
|
||||
print("模型初始化成功。")
|
||||
|
||||
def synthesize(target_text, output_wav_path):
    """Synthesize speech for *target_text* with GPT-SoVITS.

    Args:
        target_text: text to synthesize.
        output_wav_path: path where the resulting WAV is written.

    Returns:
        The output path on success, or None if the model produced nothing.
    """
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)

        # get_tts_wav yields chunks; materialize while the device is active.
        result_list = list(synthesis_result)

    if result_list:
        # Only the final (sampling_rate, samples) pair is written out.
        last_sampling_rate, last_audio_data = result_list[-1]
        sf.write(output_wav_path, last_audio_data, last_sampling_rate)
        return output_wav_path
    else:
        return None
|
||||
|
||||
def kafka_consumer():
    """Consume TTS tasks from Kafka and synthesize audio for each.

    For every message: mark the task "processing", synthesize (or reuse an
    existing file), store the result path in Redis, then mark the task
    "completed" or "failed".

    Fix vs. the previous version: *task_id* was referenced in the except
    handler before being guaranteed a value — a malformed message raised
    NameError inside the handler instead of being reported.
    """
    consumer = KafkaConsumer(
        KAFKA_TTS_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    print(f"TTS消费者已启动")
    for message in consumer:
        task_id = None  # bound before parsing so the error path can use it safely
        try:
            task_id = message.value['task_id']
            target_text = message.value['text']
            text_hash = message.value['text_hash']

            # Mark the task as "processing".
            redis_task_client.set(f"task_status:tts:{task_id}", "processing")

            output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav")

            # Re-check for the file (another worker may have produced it already).
            if not os.path.exists(output_wav_path):
                output_path = synthesize(target_text, output_wav_path)
            else:
                output_path = output_wav_path

            if output_path:
                # Store the result path (TTS result DB).
                redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path}))
                print(f"音频合成成功: {output_path}")

                # Mark the task as "completed".
                redis_task_client.set(f"task_status:tts:{task_id}", "completed")
            else:
                print("音频合成失败")

                # Mark the task as "failed".
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
        except Exception as e:
            print(f"处理消息时出错: {str(e)}")

            # Only mark "failed" when we know which task the message was for.
            if task_id is not None:
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 设置CUDA设备
|
||||
torch.cuda.set_device(device)
|
||||
# 启动Kafka消费者
|
||||
kafka_consumer()
|
||||
@@ -1,68 +0,0 @@
|
||||
import os
|
||||
import whisper
|
||||
import argparse
|
||||
|
||||
def transcribe_audio(model, audio_path):
    """Transcribe one audio file with a loaded Whisper model.

    :param model: loaded Whisper model
    :param audio_path: path to the audio file
    :return: transcribed text, or None if transcription failed
    """
    try:
        return model.transcribe(audio_path)["text"]
    except Exception as e:
        print(f"转录失败 {audio_path}: {str(e)}")
        return None
|
||||
|
||||
def process_directory(directory, model):
    """Transcribe every .wav in *directory*, writing a sibling .txt per file.

    :param directory: path containing WAV files (non-recursive)
    :param model: loaded Whisper model
    """
    for filename in os.listdir(directory):
        if filename.lower().endswith('.wav'):
            input_file = os.path.join(directory, filename)
            output_file = os.path.splitext(input_file)[0] + ".txt"

            print(f"正在处理: {input_file}")
            transcription = transcribe_audio(model, input_file)

            if transcription:
                with open(output_file, 'w', encoding='utf-8') as f:
                    f.write(transcription)
                print(f"转录完成: {output_file}")
            else:
                print(f"转录失败: {input_file}")
|
||||
|
||||
def main():
    """CLI entry point: transcribe one WAV file or a directory of WAV files."""
    parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
    parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"], help="Whisper模型大小")
    args = parser.parse_args()

    print(f"正在加载Whisper模型 ({args.model})...")
    model = whisper.load_model(args.model)
    print("模型加载完成")

    if os.path.isfile(args.input_path):
        # Single-file mode: must be a .wav; the .txt goes next to the input.
        if not args.input_path.lower().endswith('.wav'):
            print("错误: 输入文件不是WAV格式")
            return
        output_file = os.path.splitext(args.input_path)[0] + ".txt"
        transcription = transcribe_audio(model, args.input_path)
        if transcription:
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(transcription)
            print(f"转录完成: {output_file}")
    elif os.path.isdir(args.input_path):
        process_directory(args.input_path, model)
    else:
        print("错误: 输入路径既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1 +0,0 @@
|
||||
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
@@ -1,186 +0,0 @@
|
||||
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import whisper
|
||||
import tempfile
|
||||
import uvicorn
|
||||
from kafka import KafkaProducer, KafkaConsumer
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
import json
|
||||
import threading
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import logging
|
||||
import redis
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 设置要使用的GPU ID
|
||||
GPU_ID = 1 # 修改这个值来选择要使用的GPU
|
||||
|
||||
# 设置CUDA_VISIBLE_DEVICES环境变量
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# CORS 配置
|
||||
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
|
||||
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
print("正在加载Whisper模型...")
|
||||
model = whisper.load_model("large-v3")
|
||||
print("Whisper模型加载完成。")
|
||||
|
||||
# Kafka配置
|
||||
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
|
||||
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
|
||||
|
||||
# Redis配置
|
||||
REDIS_HOST = os.getenv('REDIS_HOST')
|
||||
REDIS_PORT = int(os.getenv('REDIS_PORT'))
|
||||
REDIS_DB = int(os.getenv('REDIS_ASR_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
port=REDIS_PORT,
|
||||
db=REDIS_DB,
|
||||
password=REDIS_PASSWORD # 添加密码
|
||||
)
|
||||
|
||||
# Kafka生产者
|
||||
producer = KafkaProducer(
|
||||
bootstrap_servers=[KAFKA_BROKER],
|
||||
value_serializer=lambda v: json.dumps(v).encode('utf-8')
|
||||
)
|
||||
|
||||
# 存储WebSocket连接的字典
|
||||
active_connections = {}
|
||||
|
||||
@app.websocket("/asr/ws/{client_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, client_id: str):
|
||||
await websocket.accept()
|
||||
active_connections[client_id] = websocket
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 设置接收超时
|
||||
data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
else:
|
||||
await websocket.send_text(f"收到消息: {data}")
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
# 发送心跳
|
||||
await websocket.send_text("heartbeat")
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端 {client_id} 断开连接")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"客户端 {client_id} 断开连接")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
finally:
|
||||
if client_id in active_connections:
|
||||
del active_connections[client_id]
|
||||
|
||||
|
||||
@app.post("/asr")
|
||||
async def transcribe(audio: UploadFile = File(...)):
|
||||
if not audio:
|
||||
raise HTTPException(status_code=400, detail="未提供音频文件")
|
||||
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = f"asr:{audio.filename}:{client_id}"
|
||||
|
||||
# 检查缓存
|
||||
cached_result = redis_client.get(cache_key)
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {cache_key}")
|
||||
return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
|
||||
content = await audio.read()
|
||||
temp_audio.write(content)
|
||||
temp_audio.flush()
|
||||
|
||||
task = {
|
||||
'file_path': temp_audio.name,
|
||||
'client_id': client_id,
|
||||
'cache_key': cache_key
|
||||
}
|
||||
producer.send(KAFKA_TOPIC, value=task)
|
||||
producer.flush()
|
||||
|
||||
logger.info(f"发送任务到Kafka: {task}")
|
||||
return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
|
||||
|
||||
async def send_transcription(client_id: str, transcription: str):
    """Push a transcription to the client's registered WebSocket, if any."""
    ws = active_connections.get(client_id)
    if ws is None:
        logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
        return
    await ws.send_json({"transcription": transcription})
def kafka_consumer(consumer_id):
    """Blocking Kafka consumer loop: transcribe queued audio files.

    For each task: run the ASR model, cache the text in Redis for an hour,
    push it to the client's WebSocket, delete the temp file and commit the
    offset. Runs on a worker thread (see start_consumers).
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda x: json.loads(x.decode('utf-8')),
        group_id='asr_group',
        max_poll_interval_ms=300000
    )

    for message in consumer:
        try:
            task = message.value
            file_path = task.get('file_path')
            client_id = task.get('client_id')
            cache_key = task.get('cache_key')

            if not file_path or not client_id or not cache_key:
                # Malformed payload: commit so it is not redelivered forever.
                logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
                consumer.commit()
                continue

            result = model.transcribe(file_path)

            logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
            logger.info(f"转录结果: {result['text']}")

            # Cache the transcription for 1 hour.
            redis_client.setex(cache_key, 3600, result['text'])

            # This thread has no running event loop, so spin one up per push.
            asyncio.run(send_transcription(client_id, result['text']))

            # FIX: guard the removal — on redelivery of a partially processed
            # task the file may already be gone, and os.remove would raise.
            if os.path.exists(file_path):
                os.remove(file_path)
            consumer.commit()
        except Exception as e:
            # NOTE(review): no commit here, so a permanently failing task is
            # redelivered indefinitely — confirm whether a dead-letter path
            # or commit-on-error is wanted.
            logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
def start_consumers(num_consumers=1):
    """Spawn `num_consumers` background Kafka consumer threads.

    FIX: threads are now daemonic and named — the previous non-daemon
    threads kept the process alive after uvicorn shut down, blocking exit.
    """
    for i in range(num_consumers):
        worker = threading.Thread(
            target=kafka_consumer,
            args=(i,),
            name=f"asr-consumer-{i}",
            daemon=True,
        )
        worker.start()
if __name__ == '__main__':
    # Launch the background ASR consumers, then serve the API.
    start_consumers()
    uvicorn.run(app, host="0.0.0.0", port=6000)
+30
-2
@@ -23,6 +23,10 @@ REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
|
||||
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
|
||||
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
|
||||
|
||||
OLLAMA_URL = os.getenv('OLLAMA_URL')
|
||||
OLLAMA_URLS = os.getenv('OLLAMA_URLS', OLLAMA_URL).split(',') # 兼容旧配置
|
||||
OLLAMA_TIMEOUT = int(os.getenv('OLLAMA_TIMEOUT', 10))
|
||||
|
||||
# 创建Redis客户端
|
||||
redis_client = redis.Redis(
|
||||
host=REDIS_HOST,
|
||||
@@ -51,6 +55,21 @@ def create_kafka_consumer():
|
||||
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
|
||||
)
|
||||
|
||||
async def try_ollama_request(url, data):
    """Try a single Ollama API endpoint.

    Returns the streaming requests.Response on success, or None when the
    request fails or returns a non-2xx status, so the caller can fall
    through to the next configured URL.
    """
    import asyncio  # local import: only needed for the executor hop

    def _post():
        resp = requests.post(
            f"{url}/api/generate",
            json=data,
            stream=True,
            timeout=OLLAMA_TIMEOUT
        )
        resp.raise_for_status()
        return resp

    try:
        # FIX: requests.post is blocking; run it on a worker thread so this
        # coroutine does not stall the event loop while connecting/waiting.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _post)
    except Exception as e:
        print(f"API {url} 请求失败: {str(e)}")
        return None
async def process_chat_request(chat_request):
|
||||
try:
|
||||
task_id = chat_request['task_id']
|
||||
@@ -77,9 +96,17 @@ async def process_chat_request(chat_request):
|
||||
"temperature": 0
|
||||
}
|
||||
|
||||
response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
|
||||
response.raise_for_status()
|
||||
# 尝试所有可用的 API 地址
|
||||
response = None
|
||||
for url in OLLAMA_URLS:
|
||||
response = await try_ollama_request(url, data)
|
||||
if response is not None:
|
||||
print(f"使用 API 地址: {url}")
|
||||
break
|
||||
|
||||
if response is None:
|
||||
raise Exception("所有 API 地址均不可用")
|
||||
|
||||
text_output = ""
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
@@ -105,6 +132,7 @@ async def process_chat_request(chat_request):
|
||||
# 设置任务状态为 "error"
|
||||
redis_task_client.set(f"chat:{task_id}:status", "error")
|
||||
redis_task_client.set(f"chat:{task_id}:error", str(e))
|
||||
|
||||
def kafka_consumer_thread(consumer_id):
|
||||
consumer = create_kafka_consumer()
|
||||
print(f"消费者 {consumer_id} 已启动")
|
||||
|
||||
Regular → Executable
@@ -1,94 +0,0 @@
|
||||
# Kafka 配置
|
||||
KAFKA_BROKER=222.186.10.253:9092
|
||||
KAFKA_ASR_TOPIC=asr
|
||||
KAFKA_CHAT_TOPIC=chat
|
||||
KAFKA_TTS_TOPIC=tts
|
||||
|
||||
|
||||
# Redis 配置
|
||||
REDIS_HOST=150.158.144.159
|
||||
REDIS_PORT=13003
|
||||
REDIS_ASR_DB=12
|
||||
REDIS_CHAT_DB=13
|
||||
REDIS_TTS_DB=14
|
||||
REDIS_PASSWORD=Obscura@2024
|
||||
REDIS_API_DB=2
|
||||
REDIS_API_USAGE_DB=3
|
||||
REDIS_TASK_DB=11
|
||||
REDIS_SESSION_DB=63
|
||||
|
||||
REDIS_SESSION_DB_ZH=63
|
||||
REDIS_SESSION_DB_EN=62
|
||||
REDIS_SESSION_DB_KO=61
|
||||
|
||||
# CORS 配置
|
||||
# ALLOWED_ORIGINS=https://beta.obscura.work
|
||||
|
||||
|
||||
# GPT-SoVITS 配置
|
||||
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||
REF_AUDIO_PATH=sample/woman.wav
|
||||
REF_TEXT_PATH=sample/woman.txt
|
||||
REF_LANGUAGE=中文
|
||||
TARGET_LANGUAGE=多语种混合
|
||||
OUTPUT_PATH=/obscura/task/audio_files
|
||||
|
||||
# VOICE_CONFIGS
|
||||
GIRL_REF_AUDIO=sample/gril.wav
|
||||
GIRL_REF_TEXT=sample/gril.txt
|
||||
|
||||
WOMAN_REF_AUDIO=sample/woman.wav
|
||||
WOMAN_REF_TEXT=sample/woman.txt
|
||||
|
||||
|
||||
MAN_REF_AUDIO=sample/man.wav
|
||||
MAN_REF_TEXT=sample/man.txt
|
||||
|
||||
LEIJUN_REF_AUDIO=sample/leijun.wav
|
||||
LEIJUN_REF_TEXT=sample/leijun.txt
|
||||
|
||||
DUFU_REF_AUDIO=sample/dufu.wav
|
||||
DUFU_REF_TEXT=sample/dufu.txt
|
||||
|
||||
HEJIONG_REF_AUDIO=sample/hejiong.wav
|
||||
HEJIONG_REF_TEXT=sample/hejiong.txt
|
||||
|
||||
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
|
||||
MAHUATENG_REF_TEXT=sample/mahuateng.txt
|
||||
|
||||
LIDAN_REF_AUDIO=sample/lidan.wav
|
||||
LIDAN_REF_TEXT=sample/lidan.txt
|
||||
|
||||
YUHUA_REF_AUDIO=sample/yuhua.wav
|
||||
YUHUA_REF_TEXT=sample/yuhua.txt
|
||||
|
||||
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
|
||||
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
|
||||
|
||||
DABING_REF_AUDIO=sample/dabing.wav
|
||||
DABING_REF_TEXT=sample/dabing.txt
|
||||
|
||||
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
|
||||
LUOXIANG_REF_TEXT=sample/luoxiang.txt
|
||||
|
||||
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
|
||||
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
|
||||
|
||||
|
||||
REDIS_GIRL_DB = 15
|
||||
REDIS_WOMAN_DB = 16
|
||||
REDIS_MAN_DB = 17
|
||||
REDIS_LEIJUN_DB = 18
|
||||
REDIS_DUFU_DB = 19
|
||||
REDIS_HEJIONG_DB = 20
|
||||
REDIS_MAHUATENG_DB = 21
|
||||
REDIS_LIDAN_DB = 22
|
||||
REDIS_DABING_DB = 23
|
||||
REDIS_LUOXIANG_DB = 24
|
||||
REDIS_XUZHIYUAN_DB = 25
|
||||
REDIS_YUHUA_DB = 26
|
||||
REDIS_LIUZHENYUN_DB = 27
|
||||
|
||||
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from kafka import KafkaProducer
|
||||
from redis import Redis
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dotenv import load_dotenv
|
||||
import tempfile
|
||||
import hashlib
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Helper kept at the top of the file so every endpoint can use it.
def get_audio_hash(text):
    """Return the hex MD5 digest of *text*, used as the per-voice TTS cache key."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
# Load the .env file before reading any configuration.
load_dotenv()

app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Core service configuration (all values come from the environment; the
# int() casts fail fast at startup when a variable is missing).
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))

# Per-voice Redis DB numbers (one audio-cache DB per TTS voice).
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))

KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')

OUTPUT_PATH = os.getenv('OUTPUT_PATH')

# Kafka producer: task payloads are JSON-serialized dicts.
producer = KafkaProducer(
    bootstrap_servers=[KAFKA_BROKER],
    value_serializer=lambda v: json.dumps(v).encode('utf-8')
)


def _make_redis(db):
    # All clients share host/port/password and differ only in the db number.
    return Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=db)


# Task/metadata clients.
redis_tts_client = _make_redis(REDIS_TTS_DB)
redis_asr_client = _make_redis(REDIS_ASR_DB)
redis_chat_client = _make_redis(REDIS_CHAT_DB)
redis_api_client = _make_redis(REDIS_API_DB)
redis_api_usage_client = _make_redis(REDIS_API_USAGE_DB)
redis_task_client = _make_redis(REDIS_TASK_DB)

# Per-voice audio-cache clients.
redis_tts_girl = _make_redis(REDIS_GIRL_DB)
redis_tts_woman = _make_redis(REDIS_WOMAN_DB)
redis_tts_man = _make_redis(REDIS_MAN_DB)
redis_tts_leijun = _make_redis(REDIS_LEIJUN_DB)
redis_tts_dufu = _make_redis(REDIS_DUFU_DB)
redis_tts_hejiong = _make_redis(REDIS_HEJIONG_DB)
redis_tts_mahuateng = _make_redis(REDIS_MAHUATENG_DB)
redis_tts_lidan = _make_redis(REDIS_LIDAN_DB)
redis_tts_yuhua = _make_redis(REDIS_YUHUA_DB)
redis_tts_liuzhenyun = _make_redis(REDIS_LIUZHENYUN_DB)
redis_tts_dabing = _make_redis(REDIS_DABING_DB)
redis_tts_luoxiang = _make_redis(REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = _make_redis(REDIS_XUZHIYUAN_DB)

# Map each voice name to its Redis client ("default" is an alias for "girl").
voice_to_redis = {
    "default": redis_tts_girl,
    "girl": redis_tts_girl,
    "woman": redis_tts_woman,
    "man": redis_tts_man,
    "leijun": redis_tts_leijun,
    "dufu": redis_tts_dufu,
    "hejiong": redis_tts_hejiong,
    "mahuateng": redis_tts_mahuateng,
    "lidan": redis_tts_lidan,
    "yuhua": redis_tts_yuhua,
    "liuzhenyun": redis_tts_liuzhenyun,
    "dabing": redis_tts_dabing,
    "luoxiang": redis_tts_luoxiang,
    "xuzhiyuan": redis_tts_xuzhiyuan
}


# API key is carried in the Authorization header ("Bearer obs-...").
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)

# FIX: a second, identical definition of get_audio_hash lived here; removed —
# the function is already defined at the top of the file.
# Extract and syntactically validate the API key from the request.
async def get_api_key(api_key: str = Security(api_key_header)):
    """Accept only "Bearer obs-..." credentials; everything else is a 401."""
    unauthorized = HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if not api_key or not api_key.startswith("Bearer "):
        raise unauthorized
    token = api_key.split(" ")[1]
    if not token.startswith("obs-"):
        raise unauthorized
    return token
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Check the key against Redis: it must exist, be active and unexpired.

    Returns a dict merging the key's metadata and its usage counters.
    Raises HTTPException(401) on any failed check.
    """
    redis_key = f"api_key:{api_key}"

    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")

    # Redis returns bytes for both keys and values; decode everything once.
    info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}

    if info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")

    expires_at = datetime.fromisoformat(info.get('expires_at'))
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")

    raw_usage = redis_api_usage_client.hgetall(redis_key)
    usage = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_usage.items()}

    return {"api_key": api_key, **info, **usage}
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Bump token counters (overall and per-model) and last-used timestamps."""
    redis_key = f"api_key:{api_key}"
    now = datetime.now(timezone.utc).isoformat()

    # One pipeline so all four writes are flushed together.
    pipe = redis_api_usage_client.pipeline()
    pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
    pipe.hset(redis_key, "last_used_at", now)
    pipe.hincrby(redis_key, f"{model_name}_tokens_used", new_tokens_used)
    pipe.hset(redis_key, f"{model_name}_last_used_at", now)
    pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge tokens for a request and enqueue it on Kafka.

    Raises HTTPException(403) when the key's token budget would be exceeded;
    otherwise records the usage, publishes task_data to kafka_topic and
    returns a usage summary for the response body.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"

    # NOTE(review): read-then-increment is not atomic; concurrent requests
    # could slightly overspend the budget — confirm whether that matters.
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)

    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")

    # Record the charge, then publish the task.
    await update_token_usage(api_key, tokens_required, model_name)
    producer.send(kafka_topic, task_data)

    # Re-read the counters so the response reflects this request's charge.
    refreshed = await verify_api_key(api_key)
    new_tokens_used = int(refreshed.get("tokens_used", 0))
    model_tokens_used = int(refreshed.get(f"{model_name}_tokens_used", 0))

    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
class TTSRequest(BaseModel):
    """Request body for POST /v1_chat/tts: the text to synthesize and the voice name."""
    text: str
    voice: str = Field(..., description="选择的音色")
class ChatRequest(BaseModel):
    """Request body for POST /v1_chat/chat; model defaults to qwen2.5:3b."""
    session_id: str
    query: str
    model: str = "qwen2.5:3b"
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a TTS job, or return the cached audio path for identical text.

    The audio cache is keyed by the MD5 of the text, in the per-voice
    Redis DB. A cache hit returns immediately without charging/queueing.
    """
    task_id = str(uuid.uuid4())
    text_hash = get_audio_hash(request.text)

    # Validate the requested voice.
    valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
    if request.voice not in valid_voices:
        raise HTTPException(status_code=400, detail="无效的音色选择")

    # 'default' is an alias for 'girl'.
    # FIX: the normalized name was computed but then ignored; use it
    # consistently (behavior-equivalent, since the map aliases "default").
    voice = 'girl' if request.voice == 'default' else request.voice
    redis_tts = voice_to_redis[voice]

    # Serve from cache when the same text was already synthesized.
    existing_audio_info = redis_tts.get(f"tts:{text_hash}")
    if existing_audio_info:
        existing_audio_path = json.loads(existing_audio_info)['path']
        if os.path.exists(existing_audio_path):
            return {
                "message": "TTS请求已完成",
                "task_id": task_id,
                "status": "completed",
                "audio_path": existing_audio_path
            }

    # Cache miss: create a new task.
    task_data = {
        "task_id": task_id,
        "text": request.text,
        "text_hash": text_hash,
        "voice": request.voice,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # FIX: also persist the task payload — the /tts_result and /tts_audio
    # endpoints read task_info:tts:{task_id}, which was never written here.
    redis_task_client.set(f"task_status:tts:{task_id}", "queued")
    redis_task_client.set(f"task_info:tts:{task_id}", json.dumps(task_data))

    result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Persist an uploaded audio file under its task id and queue an ASR task."""
    task_id = str(uuid.uuid4())

    UPLOAD_DIR = "/obscura/task/audio_upload"
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")

    # Save the upload where the ASR worker can find it.
    with open(file_path, "wb") as out:
        out.write(await audio.read())

    task_data = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }

    # Status key format matches what /asr_result polls.
    redis_task_client.set(f"task_status:asr:{task_id}", "queued")

    result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a chat-completion task and return its task_id plus usage info."""
    task_id = str(uuid.uuid4())
    task_data = dict(
        task_id=task_id,
        session_id=request.session_id,
        query=request.query,
        model=request.model,
        status="queued",
        created_at=datetime.now(timezone.utc).isoformat(),
    )

    # Mark the task queued before publishing, so pollers never see a gap.
    redis_task_client.set(f"chat:{task_id}:status", "queued")

    result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a chat task: returns its status, plus the result when completed."""
    raw_status = redis_task_client.get(f"chat:{task_id}:status")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        chat_result = redis_task_client.get(f"chat:{task_id}:result")
        if chat_result:
            return {"status": "completed", "result": json.loads(chat_result)}
    # Pending/error, or completed with the payload not yet visible.
    return {"status": status}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a TTS task; when completed, resolve the audio path for its voice."""
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        task_info = redis_task_client.get(f"task_info:tts:{task_id}")
        if task_info:
            task_data = json.loads(task_info)
            text_hash = task_data['text_hash']
            voice = task_data['voice']
            # 'default' and 'girl' share the girl cache DB.
            redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]

            audio_info = redis_tts.get(f"tts:{text_hash}")
            if audio_info:
                return {"status": "completed", "audio_path": json.loads(audio_info)['path']}
    return {"status": status}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll an ASR task; returns the transcription once completed."""
    # Status key format matches what /asr writes.
    raw_status = redis_task_client.get(f"task_status:asr:{task_id}")
    if not raw_status:
        return {"status": "not_found"}

    status = raw_status.decode('utf-8')
    if status == "completed":
        transcription = redis_asr_client.get(f"asr:{task_id}")
        if transcription:
            return {"status": "completed", "transcription": transcription.decode('utf-8')}
    return {"status": status}
@v1_chat_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Stream the synthesized WAV for a completed TTS task.

    202 while queued/processing, 500 when the task failed, 404 when the
    task, its metadata or the audio file cannot be found.
    """
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if raw_status:
        status = raw_status.decode('utf-8')
        if status == "completed":
            task_info = redis_task_client.get(f"task_info:tts:{task_id}")
            if task_info:
                task_data = json.loads(task_info)
                voice = task_data.get('voice', 'girl')  # fall back to 'girl'
                # 'default' and 'girl' share the girl cache DB.
                redis_tts = voice_to_redis['girl'] if voice in ['default', 'girl'] else voice_to_redis[voice]

                # Resolve the audio path from the per-voice cache.
                audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
                if audio_info:
                    audio_path = json.loads(audio_info)['path']
                    if not os.path.exists(audio_path):
                        raise HTTPException(status_code=404, detail="音频文件不存在")
                    return FileResponse(
                        audio_path,
                        media_type="audio/wav",
                        filename=os.path.basename(audio_path),
                    )
        elif status in ("queued", "processing"):
            raise HTTPException(status_code=202, detail="音频文件正在生成中")
        else:
            raise HTTPException(status_code=500, detail="任务处理失败")

    # No status, missing task_info, or no cached audio for a completed task.
    raise HTTPException(status_code=404, detail="任务不存在")
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
    """List the voice names accepted by the /tts endpoint."""
    voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
    return {"available_voices": voices}
if __name__ == "__main__":
    import uvicorn
    # Serve the public API gateway.
    uvicorn.run(app, host="0.0.0.0", port=8008)
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
pydantic
|
||||
kafka-python
|
||||
redis
|
||||
python-dotenv
|
||||
python-multipart
|
||||
Regular → Executable
Regular → Executable
@@ -0,0 +1 @@
|
||||
{"GPT": {"v1": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v1": "GPT_SoVITS/pretrained_models/s2G488k.pth", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
|
||||
Reference in New Issue
Block a user