update 20250403

This commit is contained in:
2025-04-03 06:21:55 +00:00
parent 0e57b2d02e
commit f557be8b7e
110 changed files with 569 additions and 16526 deletions
+30 -29
View File
@@ -1,28 +1,27 @@
# Kafka 配置
KAFKA_BROKER=222.186.10.253:9092
KAFKA_BROKER=222.186.20.67:9092
KAFKA_ASR_TOPIC=asr
KAFKA_CHAT_TOPIC=chat
KAFKA_TTS_TOPIC=tts
# Redis 配置
REDIS_HOST=150.158.144.159
REDIS_PORT=13003
REDIS_ASR_DB=12
REDIS_CHAT_DB=13
REDIS_TTS_DB=14
REDIS_HOST=222.186.20.67
REDIS_PORT=6379
REDIS_ASR_DB=43
REDIS_CHAT_DB=44
REDIS_TTS_DB=45
REDIS_PASSWORD=Obscura@2024
REDIS_API_DB=2
REDIS_API_USAGE_DB=3
REDIS_TASK_DB=11
REDIS_SESSION_DB=63
REDIS_API_DB=31
REDIS_API_USAGE_DB=32
REDIS_TASK_DB=46
REDIS_SESSION_DB=47
REDIS_SESSION_DB_ZH=48
REDIS_SESSION_DB_EN=49
REDIS_SESSION_DB_KO=50
REDIS_SESSION_DB_ZH=63
REDIS_SESSION_DB_EN=62
REDIS_SESSION_DB_KO=61
# CORS 配置
# ALLOWED_ORIGINS=https://beta.obscura.work
# GPT-SoVITS 配置
@@ -76,19 +75,21 @@ XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
REDIS_GIRL_DB = 15
REDIS_WOMAN_DB = 16
REDIS_MAN_DB = 17
REDIS_LEIJUN_DB = 18
REDIS_DUFU_DB = 19
REDIS_HEJIONG_DB = 20
REDIS_MAHUATENG_DB = 21
REDIS_LIDAN_DB = 22
REDIS_DABING_DB = 23
REDIS_LUOXIANG_DB = 24
REDIS_XUZHIYUAN_DB = 25
REDIS_YUHUA_DB = 26
REDIS_LIUZHENYUN_DB = 27
REDIS_GIRL_DB = 51
REDIS_WOMAN_DB = 52
REDIS_MAN_DB = 53
REDIS_LEIJUN_DB = 54
REDIS_DUFU_DB = 55
REDIS_HEJIONG_DB = 56
REDIS_MAHUATENG_DB = 57
REDIS_LIDAN_DB = 58
REDIS_DABING_DB = 59
REDIS_LUOXIANG_DB = 60
REDIS_XUZHIYUAN_DB = 61
REDIS_YUHUA_DB = 62
REDIS_LIUZHENYUN_DB = 63
# Ollama API配置 - 多个地址用逗号分隔
OLLAMA_URLS=http://222.186.20.67:11435,http://222.186.20.67:11436,http://222.186.20.67:11437,http://222.186.20.67:11438,http://222.186.20.67:11439,http://222.186.20.67:11440,http://222.186.20.67:11441
OLLAMA_TIMEOUT=10 # API请求超时时间(秒)
Regular → Executable
View File
View File
+4 -1
View File
@@ -6,6 +6,9 @@ from dotenv import load_dotenv
from kafka import KafkaConsumer
import asyncio
# 在导入其他库之前设置
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
@@ -16,7 +19,7 @@ os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
load_dotenv()
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
model = whisper.load_model("small")
print("Whisper模型加载完成。")
# Kafka配置
-115
View File
@@ -1,115 +0,0 @@
import whisper
import os
import json
import redis
from dotenv import load_dotenv
from kafka import KafkaConsumer
import threading
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
# 加载环境变量
load_dotenv()
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
# Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
# Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_asr_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_ASR_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
def process_audio(file_path: str, client_id: str, cache_key: str):
    """Transcribe one uploaded audio file and publish the result.

    Marks the task processing/completed/error in the task DB, caches the
    transcription for an hour, and pushes it on the 'asr_results' channel.
    """
    status_key = f"task_status:{cache_key}"
    try:
        redis_task_client.set(status_key, "processing")
        transcription = model.transcribe(file_path)['text']
        print(f"处理了文件: {file_path}")
        print(f"转录结果: {transcription}")
        # Cache the text for one hour so repeated polls stay cheap.
        redis_asr_client.setex(cache_key, 3600, transcription)
        payload = json.dumps({
            'client_id': client_id,
            'transcription': transcription
        })
        redis_asr_client.publish('asr_results', payload)
        redis_task_client.set(status_key, "completed")
        # Drop the temporary upload once it has been transcribed.
        os.remove(file_path)
    except Exception as e:
        print(f"处理音频文件时发生错误: {str(e)}")
        redis_task_client.set(status_key, "error")
def kafka_consumer():
    """Blocking loop: pull ASR tasks off Kafka and run process_audio."""
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8')),
        group_id='asr_group',
        auto_offset_reset='earliest',
        enable_auto_commit=True
    )
    print(f"ASR消费者已启动")
    for message in consumer:
        try:
            task = message.value
            file_path = task.get('file_path')
            task_id = task.get('task_id')
            # Only freshly queued tasks that carry a file are actionable.
            if not file_path or not task_id or task.get('status') != 'queued':
                print(f"收到无效任务: {task}")
                continue
            cache_key = f"asr:{task_id}"
            print(f"开始处理任务: {cache_key}")
            # The task id doubles as the client id for result routing.
            process_audio(file_path, task_id, cache_key)
            print(f"完成处理任务: {cache_key}")
        except Exception as e:
            print(f"处理消息时发生错误: {str(e)}")


if __name__ == "__main__":
    print("启动Kafka消费者处理ASR请求...")
    kafka_consumer()
-117
View File
@@ -1,117 +0,0 @@
from kafka import KafkaConsumer
import json
import asyncio
import redis
import os
from dotenv import load_dotenv
import requests
from concurrent.futures import ThreadPoolExecutor
# 加载 .env 文件
load_dotenv()
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 3 # 消费者数量
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_CHAT_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
# 创建Kafka消费者
def create_kafka_consumer():
    """Build a Kafka consumer on the chat topic decoding JSON payloads."""
    return KafkaConsumer(
        KAFKA_CHAT_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        group_id=KAFKA_CONSUMER_GROUP,
        auto_offset_reset='latest',
        enable_auto_commit=True,
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8'))
    )
async def process_chat_request(chat_request):
    """Answer one queued chat task via the local Ollama server.

    Replays the stored history into the prompt, streams the completion,
    appends the new turn to the session history, and tracks task status
    (processing/completed/error) in the task DB.

    Fixes vs. original:
    - requests.post now carries a timeout so a wedged Ollama instance
      cannot hang this consumer thread forever.
    - the error path reads session_id with .get(), so a malformed task
      cannot raise KeyError inside the handler and hide the real error.
    """
    session_id = None
    try:
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')
        redis_task_client.set(f"task_status:{session_id}", "processing")
        # History is stored as a JSON list of (query, response) pairs.
        history = json.loads(redis_client.get(session_id) or '[]')
        full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
        for past_query, past_response in history:
            full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
        full_prompt += f"用户: {query}\n助手:"
        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            "temperature": 0
        }
        response = requests.post("http://127.0.0.1:11434/api/generate",
                                 json=data, stream=True, timeout=300)
        response.raise_for_status()
        text_output = ""
        for line in response.iter_lines():
            if line:
                json_data = json.loads(line)
                if 'response' in json_data:
                    text_output += json_data['response']
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))
        redis_task_client.set(f"task_status:{session_id}", "completed")
        print(f"处理完成 session {session_id}: {text_output}")
    except Exception as e:
        sid = session_id if session_id is not None else chat_request.get('session_id')
        print(f"处理 session {sid} 时出错: {str(e)}")
        redis_task_client.set(f"task_status:{sid}", "error")
def kafka_consumer_thread(consumer_id):
    """Worker loop: one Kafka consumer draining chat tasks forever."""
    consumer = create_kafka_consumer()
    print(f"消费者 {consumer_id} 已启动")
    for message in consumer:
        # Each task gets a fresh event loop on this thread.
        asyncio.run(process_chat_request(message.value))


def main():
    """Fan out KAFKA_CONSUMER_NUM consumer threads and block on them."""
    print("启动Kafka消费者处理聊天请求...")
    with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
        for worker_id in range(KAFKA_CONSUMER_NUM):
            executor.submit(kafka_consumer_thread, worker_id)


if __name__ == '__main__':
    main()
-182
View File
@@ -1,182 +0,0 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
from typing import List, Tuple
from kafka import KafkaConsumer, TopicPartition
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(f"Using device: {device}")
# Load MiniCPM3-4B model
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
# 加载 .env 文件
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
KAFKA_CONSUMER_GROUP = 'mini3_group'
KAFKA_CONSUMER_NUM = 1
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 使用密码进行认证
)
# 创建Kafka消费者
def create_kafka_consumer():
    """Build an unsubscribed Kafka consumer decoding JSON payloads."""
    return KafkaConsumer(
        bootstrap_servers=KAFKA_BROKER,
        group_id=KAFKA_CONSUMER_GROUP,
        auto_offset_reset='earliest',
        enable_auto_commit=True,
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8'))
    )
# Kafka消费者函数
def kafka_consumer(consumer, consumer_id):
    """Subscribe the given consumer to the chat topic and process messages."""
    consumer.subscribe([KAFKA_TOPIC])
    # NOTE(review): assignment() is typically empty until the first poll,
    # so this log line may show no partitions — confirm if that matters.
    partitions = consumer.assignment()
    logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
    for message in consumer:
        chat_request = message.value  # already deserialized to a dict
        partition = message.partition
        logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
        asyncio.run(process_chat_request(chat_request))
# 启动Kafka消费者线程
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
    """Spawn num_consumers daemon threads, each running kafka_consumer.

    Returns the (consumer, thread) pairs so callers can inspect them.
    """
    started = []
    for idx in range(num_consumers):
        consumer = create_kafka_consumer()
        worker = threading.Thread(target=kafka_consumer, args=(consumer, idx), daemon=True)
        worker.start()
        started.append((consumer, worker))
    return started
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
class ChatRequest(BaseModel):
    # One user turn within an existing session; model defaults to MiniCPM3.
    session_id: str
    query: str
    model: str = "minicpm3-4b"


class ChatResponse(BaseModel):
    # Assistant reply plus the full (query, response) history so far.
    response: str
    history: List[Tuple[str, str]]
# 处理聊天请求的异步函数
async def process_chat_request(chat_request):
    """Run one Kafka-delivered task through the /mini3 chat handler."""
    try:
        result = await chat(ChatRequest(**chat_request))
        print(f"Processed message for session {chat_request['session_id']}: {result}")
    except Exception as e:
        print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
@app.post("/mini3", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a MiniCPM3 reply for one session turn, with history."""
    session_id = request.session_id
    query = request.query
    # Prior turns for this session, stored as a JSON list of (q, a) pairs.
    history = json.loads(redis_client.get(session_id) or '[]')
    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    for past_query, past_response in history:
        messages.append({"role": "user", "content": past_query})
        messages.append({"role": "assistant", "content": past_response})
    messages.append({"role": "user", "content": query})
    try:
        prompt_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        # pad == eos here, so the mask marks every non-padding position.
        mask = prompt_ids.ne(tokenizer.pad_token_id).long()
        prompt_ids = prompt_ids.to(device)
        mask = mask.to(device)
        generated = model.generate(
            prompt_ids,
            attention_mask=mask,
            max_new_tokens=1024,
            top_p=0.7,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # pad with eos during sampling
            do_sample=True
        )
        # Keep only the newly generated tail, dropping the prompt tokens.
        new_token_ids = generated[0][len(prompt_ids[0]):]
        text_output = tokenizer.decode(new_token_ids, skip_special_tokens=True)
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))
        return ChatResponse(response=text_output, history=history)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/start_chat")
async def start_chat():
    """Mint a fresh session id for a new conversation."""
    return {"session_id": str(uuid.uuid4())}


if __name__ == '__main__':
    # Background Kafka consumers first, then the HTTP server.
    start_kafka_consumers()
    uvicorn.run(app, host="0.0.0.0", port=6003)
-63
View File
@@ -1,63 +0,0 @@
import os
from moviepy.editor import VideoFileClip
def mp4_to_wav(input_file, output_file):
    """Extract the audio track of an MP4 file into a WAV file.

    :param input_file: path of the source MP4
    :param output_file: path of the WAV to write
    """
    try:
        clip = VideoFileClip(input_file)
        track = clip.audio
        track.write_audiofile(output_file)
        # Release the underlying file handles explicitly.
        track.close()
        clip.close()
        print(f"转换成功: {input_file} -> {output_file}")
    except Exception as e:
        print(f"转换失败: {input_file} - {str(e)}")
def process_directory(directory):
    """Convert every *.mp4 directly inside *directory* to a sibling .wav."""
    for entry in os.listdir(directory):
        if not entry.lower().endswith('.mp4'):
            continue
        source = os.path.join(directory, entry)
        target = os.path.splitext(source)[0] + ".wav"
        mp4_to_wav(source, target)
def main():
    """Interactive entry point: convert one MP4 file or a directory of them."""
    input_path = input("请输入MP4文件或包含MP4文件的目录路径: ").strip()
    # Guard clauses: reject missing paths and non-MP4 files early.
    if not os.path.exists(input_path):
        print("错误: 输入路径不存在")
        return
    if os.path.isfile(input_path):
        if not input_path.lower().endswith('.mp4'):
            print("错误: 输入文件不是MP4格式")
            return
        mp4_to_wav(input_path, os.path.splitext(input_path)[0] + ".wav")
    elif os.path.isdir(input_path):
        process_directory(input_path)
    else:
        print("错误: 输入路径既不是文件也不是目录")


if __name__ == "__main__":
    main()
-170
View File
@@ -1,170 +0,0 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
from typing import List, Tuple
from kafka import KafkaConsumer, TopicPartition
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(device)
print(f"Using device: {device}")
# 加载 .env 文件
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Kafka 设置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 1
# Redis 设置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 使用密码进行认证
)
# 创建Kafka消费者
def create_kafka_consumer():
    """Build an unsubscribed Kafka consumer decoding JSON payloads."""
    return KafkaConsumer(
        bootstrap_servers=KAFKA_BROKER,
        group_id=KAFKA_CONSUMER_GROUP,
        auto_offset_reset='earliest',
        enable_auto_commit=True,
        value_deserializer=lambda raw: json.loads(raw.decode('utf-8'))
    )
# Kafka消费者函数
def kafka_consumer(consumer, consumer_id):
    """Subscribe the given consumer to the chat topic and process messages."""
    consumer.subscribe([KAFKA_TOPIC])
    # NOTE(review): assignment() is typically empty until the first poll,
    # so this log line may show no partitions — confirm if that matters.
    partitions = consumer.assignment()
    logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
    for message in consumer:
        chat_request = message.value  # already deserialized to a dict
        partition = message.partition
        logger.info(f"消费者 {consumer_id} 正在处理来自分区 {partition} 的消息:")
        asyncio.run(process_chat_request(chat_request))
# 启动Kafka消费者线程
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
    """Spawn num_consumers daemon threads, each running kafka_consumer.

    Returns the (consumer, thread) pairs so callers can inspect them.
    """
    started = []
    for idx in range(num_consumers):
        consumer = create_kafka_consumer()
        worker = threading.Thread(target=kafka_consumer, args=(consumer, idx), daemon=True)
        worker.start()
        started.append((consumer, worker))
    return started
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
class ChatRequest(BaseModel):
    # One user turn within an existing session; model defaults to local qwen.
    session_id: str
    query: str
    model: str = "qwen2.5:3b"


class ChatResponse(BaseModel):
    # Assistant reply plus the full (query, response) history so far.
    response: str
    history: List[Tuple[str, str]]
# 处理聊天请求的异步函数
async def process_chat_request(chat_request):
    """Run one Kafka-delivered task through the /chat handler."""
    try:
        result = await chat(ChatRequest(**chat_request))
        print(f"Processed message for session {chat_request['session_id']}: {result}")
    except Exception as e:
        print(f"Error processing message for session {chat_request['session_id']}: {str(e)}")
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Answer one chat turn via the local Ollama server, with history.

    Fix vs. original: requests.post now carries a timeout so a wedged
    Ollama instance cannot hang the request handler indefinitely.
    """
    session_id = request.session_id
    query = request.query
    model = request.model
    # Prior turns for this session, stored as a JSON list of (q, a) pairs.
    history = json.loads(redis_client.get(session_id) or '[]')
    full_prompt = DEFAULT_SYSTEM_PROMPT + "\n"
    for past_query, past_response in history:
        full_prompt += f"用户: {past_query}\n助手: {past_response}\n"
    full_prompt += f"用户: {query}"
    data = {
        "model": model,
        "prompt": full_prompt,
        "stream": True,
        "temperature": 0
    }
    try:
        response = requests.post("http://127.0.0.1:11434/api/generate",
                                 json=data, stream=True, timeout=300)
        response.raise_for_status()
        text_output = ""
        for line in response.iter_lines():
            if line:
                json_data = json.loads(line)
                if 'response' in json_data:
                    text_output += json_data['response']
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))
        return ChatResponse(response=text_output, history=history)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/start_chat")
async def start_chat():
    """Mint a fresh session id for a new conversation."""
    return {"session_id": str(uuid.uuid4())}


if __name__ == '__main__':
    # Background Kafka consumers first, then the HTTP server.
    start_kafka_consumers()
    uvicorn.run(app, host="0.0.0.0", port=6001)
-136
View File
@@ -1,136 +0,0 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import httpx
import json
import redis
from typing import List, Dict, Optional
import logging
import ollama
import uuid
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Redis连接
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GenerateRequest(BaseModel):
    # Minimal single-prompt request; model defaults to the local qwen build.
    model: Optional[str] = "qwen2.5:3b"
    prompt: str


class RawGenerateRequest(BaseModel):
    # Full pass-through of Ollama's /api/generate parameters.
    model: Optional[str] = "qwen2.5:3b"
    prompt: str
    system_prompt: Optional[str] = None
    stream: Optional[bool] = False
    raw: Optional[bool] = False
    format: Optional[str] = None
    options: Optional[Dict] = None


class GenerateResponse(BaseModel):
    # Wrapped reply plus the archive id it was stored under.
    response: dict
    request_id: str
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Single-turn chat via Ollama; archives the exchange under a request id."""
    logger.info(f"收到请求: {request}")
    request_id = str(uuid.uuid4())
    try:
        reply = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
        full_response = reply['message']['content']
        # Persist the prompt/response pair for later retrieval by id.
        redis_client.set(f"request:{request_id}", json.dumps({
            "model": request.model,
            "prompt": request.prompt,
            "response": full_response
        }))
        return GenerateResponse(
            response={"response": full_response, "model": request.model},
            request_id=request_id
        )
    except Exception as e:
        logger.error(f"发生错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/generate")
async def generate_without_history(request: RawGenerateRequest):
    """Proxy a stateless generate call to Ollama and archive the result.

    NOTE(review): when request.stream is True, ollama.generate returns an
    iterator and the response['response'] access below would fail — confirm
    callers always send stream=False.
    """
    try:
        response = ollama.generate(
            model=request.model,
            prompt=request.prompt,
            system=request.system_prompt,
            format=request.format,
            options=request.options,
            stream=request.stream
        )
        # Mirror Ollama's response shape, including its timing counters.
        response_data = {
            "model": request.model,
            "response": response['response'],
            "done": True,
            "context": response.get('context'),
            "total_duration": response.get('total_duration'),
            "load_duration": response.get('load_duration'),
            "prompt_eval_count": response.get('prompt_eval_count'),
            "prompt_eval_duration": response.get('prompt_eval_duration'),
            "eval_count": response.get('eval_count'),
            "eval_duration": response.get('eval_duration')
        }
        request_id = str(uuid.uuid4())
        redis_client.set(f"request:{request_id}", json.dumps(response_data))
        return response_data
    except Exception as e:
        logger.error(f"发生未预期的错误: {e}")
        logger.exception("详细错误信息:")
        raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
@app.get("/request/{request_id}", response_model=Dict)
async def get_request(request_id: str):
    """Fetch a previously archived request/response pair by id."""
    stored = redis_client.get(f"request:{request_id}")
    if not stored:
        raise HTTPException(status_code=404, detail="请求未找到")
    return json.loads(stored)


@app.get("/models")
async def list_models():
    """List the models the local Ollama daemon knows about."""
    return ollama.list()


@app.get("/models/{model_name}")
async def show_model(model_name: str):
    """Show details for one Ollama model."""
    return ollama.show(model_name)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7000)
-319
View File
@@ -1,319 +0,0 @@
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import tempfile
import hashlib
import asyncio
# 加载 .env 文件
load_dotenv()
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
# 定义API密钥头部
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def get_audio_hash(text):
    """Stable cache key: hex MD5 digest of the UTF-8 encoded text."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
# 验证API密钥的函数
async def get_api_key(api_key: str = Security(api_key_header)):
    """Extract an 'obs-' key from a 'Bearer …' Authorization header.

    Raises 401 when the header is missing or malformed.
    """
    if api_key and api_key.startswith("Bearer "):
        candidate = api_key.split(" ")[1]
        if candidate.startswith("obs-"):
            return candidate
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate a key against Redis: must exist, be active, and not expired.

    Returns the key's metadata merged with its usage counters.
    """
    redis_key = f"api_key:{api_key}"
    raw_info = redis_api_client.hgetall(redis_key)
    if not raw_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_info.items()}
    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")
    # Expiry is stored as an ISO-8601 timestamp; compare in UTC.
    expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")
    raw_usage = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in raw_usage.items()}
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info
    }
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Bump global and per-model token counters for a key in one pipeline."""
    redis_key = f"api_key:{api_key}"
    now = datetime.now(timezone.utc).isoformat()
    # Single round trip: global counters plus the per-model pair.
    pipe = redis_api_usage_client.pipeline()
    pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
    pipe.hset(redis_key, "last_used_at", now)
    pipe.hincrby(redis_key, f"{model_name}_tokens_used", new_tokens_used)
    pipe.hset(redis_key, f"{model_name}_last_used_at", now)
    pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge tokens for a task, enqueue it on Kafka, and report usage.

    Raises 403 when the key's remaining balance cannot cover the request.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")
    await update_token_usage(api_key, tokens_required, model_name)
    producer.send(kafka_topic, task_data)
    # Re-read counters so the response reflects the charge just made.
    refreshed = await verify_api_key(api_key)
    new_tokens_used = int(refreshed.get("tokens_used", 0))
    model_tokens_used = int(refreshed.get(f"{model_name}_tokens_used", 0))
    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
class TTSRequest(BaseModel):
    # Text to synthesize.
    text: str


class ChatRequest(BaseModel):
    # One user turn within an existing session; model defaults to local qwen.
    session_id: str
    query: str
    model: str = "qwen2.5:3b"
# 添加WebSocket连接管理
class ConnectionManager:
    """Registry of live WebSocket connections keyed by client id."""

    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        # Accept the handshake, then track the socket for server pushes.
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        # Idempotent: unknown ids are ignored.
        self.active_connections.pop(client_id, None)

    async def send_message(self, message: str, client_id: str):
        ws = self.active_connections.get(client_id)
        if ws is not None:
            await ws.send_text(message)
manager = ConnectionManager()
@v1_chat_app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Hold a WebSocket open for a client until it disconnects."""
    await manager.connect(websocket, client_id)
    try:
        # Drain (and ignore) inbound frames; this server only pushes.
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(client_id)
# 修改TTS请求处理函数
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a TTS task: charge 100 tokens and push it onto Kafka."""
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "text": request.text,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    redis_task_client.set(f"task_status:{task_id}", "queued")
    result = await process_request(api_key_info, "tts", 100, task_data, KAFKA_TTS_TOPIC)
    result["task_id"] = task_id
    # Remember which key owns the task for later WebSocket routing.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
    return result
# 修改ASR请求处理函数
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Persist an uploaded WAV and queue it for transcription (100 tokens)."""
    task_id = str(uuid.uuid4())
    UPLOAD_DIR = "/obscura/task/audio_upload"
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{task_id}.wav")
    # Spool the whole upload to disk so the ASR worker can read it by path.
    content = await audio.read()
    with open(file_path, "wb") as out_file:
        out_file.write(content)
    task_data = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }
    redis_task_client.set(f"task_status:{task_id}", "queued")
    result = await process_request(api_key_info, "asr", 100, task_data, KAFKA_ASR_TOPIC)
    result["task_id"] = task_id
    # Remember which key owns the task for later WebSocket routing.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
    return result
# 修改聊天请求处理函数
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a chat task: charge 100 tokens and push it onto Kafka."""
    task_id = str(uuid.uuid4())
    task_data = {
        "task_id": task_id,
        "session_id": request.session_id,
        "query": request.query,
        "model": request.model,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    redis_task_client.set(f"task_status:{task_id}", "queued")
    result = await process_request(api_key_info, "chat", 100, task_data, KAFKA_CHAT_TOPIC)
    result["task_id"] = task_id
    # Remember which key owns the task for later WebSocket routing.
    redis_task_client.set(f"task_client:{task_id}", api_key_info['api_key'])
    return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a chat task: returns its status, plus history when completed."""
    task_status = redis_task_client.get(f"task_status:{task_id}")
    if not task_status:
        return {"status": "not_found"}
    status = task_status.decode('utf-8')
    if status == "completed":
        chat_result = redis_chat_client.get(task_id)
        if chat_result:
            return {
                "status": "completed",
                "history": json.loads(chat_result)  # full (query, answer) history
            }
    return {"status": status}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a TTS task: returns its status, plus the audio path when done."""
    task_status = redis_task_client.get(f"task_status:{task_id}")
    if not task_status:
        return {"status": "not_found"}
    status = task_status.decode('utf-8')
    if status == "completed":
        audio_info = redis_tts_client.get(task_id)
        if audio_info:
            return {
                "status": "completed",
                "audio_path": json.loads(audio_info)['path']
            }
    return {"status": status}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll an ASR task: returns its status, plus the text when completed."""
    task_status = redis_task_client.get(f"task_status:{task_id}")
    if not task_status:
        return {"status": "not_found"}
    status = task_status.decode('utf-8')
    if status == "completed":
        transcription = redis_asr_client.get(task_id)
        if transcription:
            return {
                "status": "completed",
                "transcription": transcription.decode('utf-8')
            }
    return {"status": status}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8008)
-180
View File
@@ -1,180 +0,0 @@
import os
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import uvicorn
import redis
import hashlib
import json
from kafka import KafkaProducer, KafkaConsumer
import threading
import time
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import os
import torch
# 加载 .env 文件
load_dotenv()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# FastAPI configuration
app = FastAPI()
i18n = I18nAuto()
# CORS configuration
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# Redis configuration
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# KAFKA_GROUP_ID = 'tts_group'
KAFKA_CONSUMER_THREADS = 1
# TTS configuration
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_PATH')
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# Initialize FastAPI CORS
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize Redis client
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD
)
# Initialize Kafka producer
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
class TTSRequest(BaseModel):
    """Request body for the /tts endpoint: the text to synthesize."""
    text: str = Field(..., alias="text")
def get_audio_hash(text):
    """Return the MD5 hex digest of *text*, used as cache key and file stem."""
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()
# Initialize models at startup
print("Initializing models...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# Read reference text
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("Models initialized successfully.")
def synthesize(target_text, output_path):
    """Synthesize *target_text* to a WAV file under the *output_path* directory.

    Uses the GPT-SoVITS weights and the reference audio/text loaded at module
    import.  The output file is named after the MD5 of the text.

    Returns:
        The written .wav path, or None when the model produced no audio.
    """
    # Synthesize audio
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
        # get_tts_wav yields chunks; the last yield holds the full result.
        result_list = list(synthesis_result)
        if result_list:
            last_sampling_rate, last_audio_data = result_list[-1]
            audio_hash = get_audio_hash(target_text)
            output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
            sf.write(output_wav_path, last_audio_data, last_sampling_rate)
            return output_wav_path
        else:
            return None
@app.post("/tts")
async def synthesize_audio(request: TTSRequest):
    """Return a WAV for *request.text*: serve from cache when possible,
    otherwise enqueue a Kafka synthesis task and wait for the output file.

    Raises:
        HTTPException 504 when the worker does not produce the file within 60s.
        HTTPException 500 on any other failure.
    """
    import asyncio  # local import: non-blocking wait without touching module imports

    try:
        print(f"Received TTS request: {request.dict()}")
        target_text = request.text
        audio_hash = get_audio_hash(target_text)

        # 1) Redis cache hit: a previous synthesis recorded the file path.
        cached_audio = redis_client.get(audio_hash)
        if cached_audio:
            audio_info = json.loads(cached_audio)
            return FileResponse(audio_info['path'], media_type="audio/wav")

        # 2) File already on disk but not cached: cache the path and serve it.
        file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
        if os.path.exists(file_path):
            redis_client.set(audio_hash, json.dumps({"path": file_path}))
            return FileResponse(file_path, media_type="audio/wav")

        # 3) Miss: hand the job to the Kafka consumer thread.
        kafka_producer.send(KAFKA_TOPIC, json.dumps({
            'text': target_text,
            'audio_hash': audio_hash
        }).encode('utf-8'))

        # Poll for the output file for up to 60 seconds.  BUG FIX: the
        # original used time.sleep(1) inside this async handler, blocking
        # the whole event loop for every concurrent request; await
        # asyncio.sleep instead.  (The old comment also said "30 seconds"
        # while the loop waited 60.)
        for _ in range(60):
            if os.path.exists(file_path):
                return FileResponse(file_path, media_type="audio/wav")
            await asyncio.sleep(1)

        raise HTTPException(status_code=504, detail="Audio generation timed out")
    except HTTPException:
        # BUG FIX: the blanket handler below used to swallow the 504 above
        # and re-raise it as a 500; let our own HTTP errors pass through.
        raise
    except Exception as e:
        print(f"Error processing TTS request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
    """Health-check endpoint confirming the TTS service is up."""
    return {"message": "TTS API is running"}
def kafka_consumer_thread():
    """Background worker: consume TTS jobs from Kafka and synthesize them.

    Each message carries the target text and its precomputed hash; on success
    the output path is cached in Redis under that hash so the HTTP handler
    can serve it.  Runs forever.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        # group_id=KAFKA_GROUP_ID,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    for message in consumer:
        target_text = message.value['text']
        audio_hash = message.value['audio_hash']
        output_path = synthesize(target_text, OUTPUT_PATH)
        if output_path:
            # Record where the audio landed so /tts can serve it from cache.
            redis_client.set(audio_hash, json.dumps({"path": output_path}))
            print(f"Audio synthesized successfully: {output_path}")
        else:
            print("Failed to synthesize audio")
# Entry point: pin the CUDA device, start the consumer thread(s), then
# serve the HTTP API on port 6002.
if __name__ == "__main__":
    # Start Kafka consumer threads
    torch.cuda.set_device(device)
    for _ in range(KAFKA_CONSUMER_THREADS):
        consumer_thread = threading.Thread(target=kafka_consumer_thread)
        consumer_thread.start()
    uvicorn.run(app, host="0.0.0.0", port=6002)
-180
View File
@@ -1,180 +0,0 @@
import os
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import uvicorn
import redis
import hashlib
import json
from kafka import KafkaProducer, KafkaConsumer
import threading
import time
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import os
import torch
# 加载 .env 文件
load_dotenv()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# FastAPI configuration
app = FastAPI()
i18n = I18nAuto()
# CORS configuration
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# Redis configuration
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# Kafka configuration
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# KAFKA_GROUP_ID = 'tts_group'
KAFKA_CONSUMER_THREADS = 1
# TTS configuration
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_KO_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_KO_PATH')
REF_LANGUAGE = os.getenv('REF_KO_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# Initialize FastAPI CORS
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize Redis client
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD
)
# Initialize Kafka producer
kafka_producer = KafkaProducer(bootstrap_servers=KAFKA_BROKER)
class TTSRequest(BaseModel):
    """Request body for the /tts_ko endpoint: the text to synthesize."""
    text: str = Field(..., alias="text")
def get_audio_hash(text):
    """MD5 hex digest of *text*; used as the audio cache key and file stem."""
    encoded = text.encode()
    return hashlib.md5(encoded).hexdigest()
# Initialize models at startup
print("Initializing models...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# Read reference text
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("Models initialized successfully.")
def synthesize(target_text, output_path):
    """Synthesize *target_text* (Korean reference voice) into *output_path*.

    Uses the GPT-SoVITS weights and the Korean reference audio/text loaded at
    module import; the output file is named after the MD5 of the text.

    Returns:
        The written .wav path, or None when the model produced no audio.
    """
    # Synthesize audio
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
        # get_tts_wav yields chunks; the last yield holds the full result.
        result_list = list(synthesis_result)
        if result_list:
            last_sampling_rate, last_audio_data = result_list[-1]
            audio_hash = get_audio_hash(target_text)
            output_wav_path = os.path.join(output_path, f"{audio_hash}.wav")
            sf.write(output_wav_path, last_audio_data, last_sampling_rate)
            return output_wav_path
        else:
            return None
@app.post("/tts_ko")
async def synthesize_audio(request: TTSRequest):
    """Return a WAV for *request.text* (Korean voice): serve from cache when
    possible, otherwise enqueue a Kafka synthesis task and wait for the file.

    Raises:
        HTTPException 504 when the worker does not produce the file within 60s.
        HTTPException 500 on any other failure.
    """
    import asyncio  # local import: non-blocking wait without touching module imports

    try:
        print(f"Received TTS request: {request.dict()}")
        target_text = request.text
        audio_hash = get_audio_hash(target_text)

        # 1) Redis cache hit: a previous synthesis recorded the file path.
        cached_audio = redis_client.get(audio_hash)
        if cached_audio:
            audio_info = json.loads(cached_audio)
            return FileResponse(audio_info['path'], media_type="audio/wav")

        # 2) File already on disk but not cached: cache the path and serve it.
        file_path = os.path.join(OUTPUT_PATH, f"{audio_hash}.wav")
        if os.path.exists(file_path):
            redis_client.set(audio_hash, json.dumps({"path": file_path}))
            return FileResponse(file_path, media_type="audio/wav")

        # 3) Miss: hand the job to the Kafka consumer thread.
        kafka_producer.send(KAFKA_TOPIC, json.dumps({
            'text': target_text,
            'audio_hash': audio_hash
        }).encode('utf-8'))

        # Poll for the output file for up to 60 seconds.  BUG FIX: the
        # original used time.sleep(1) inside this async handler, blocking
        # the whole event loop for every concurrent request; await
        # asyncio.sleep instead.  (The old comment also said "30 seconds"
        # while the loop waited 60.)
        for _ in range(60):
            if os.path.exists(file_path):
                return FileResponse(file_path, media_type="audio/wav")
            await asyncio.sleep(1)

        raise HTTPException(status_code=504, detail="Audio generation timed out")
    except HTTPException:
        # BUG FIX: the blanket handler below used to swallow the 504 above
        # and re-raise it as a 500; let our own HTTP errors pass through.
        raise
    except Exception as e:
        print(f"Error processing TTS request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
    """Health-check endpoint confirming the TTS (Korean) service is up."""
    return {"message": "TTS API is running"}
def kafka_consumer_thread():
    """Background worker: consume TTS jobs from Kafka and synthesize them.

    Each message carries the target text and its precomputed hash; on success
    the output path is cached in Redis under that hash so the HTTP handler
    can serve it.  Runs forever.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        # group_id=KAFKA_GROUP_ID,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    for message in consumer:
        target_text = message.value['text']
        audio_hash = message.value['audio_hash']
        output_path = synthesize(target_text, OUTPUT_PATH)
        if output_path:
            # Record where the audio landed so /tts_ko can serve it from cache.
            redis_client.set(audio_hash, json.dumps({"path": output_path}))
            print(f"Audio synthesized successfully: {output_path}")
        else:
            print("Failed to synthesize audio")
# Entry point: pin the CUDA device, start the consumer thread(s), then
# serve the HTTP API on port 6003.
if __name__ == "__main__":
    # Start Kafka consumer threads
    torch.cuda.set_device(device)
    for _ in range(KAFKA_CONSUMER_THREADS):
        consumer_thread = threading.Thread(target=kafka_consumer_thread)
        consumer_thread.start()
    uvicorn.run(app, host="0.0.0.0", port=6003)
-176
View File
@@ -1,176 +0,0 @@
# 导入所需的库
import os
import soundfile as sf
import redis
import hashlib
import json
from kafka import KafkaConsumer
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
from dotenv import load_dotenv
import torch
"""
整体设计说明:
这个脚本实现了一个文本到语音(TTS)的服务。它使用Kafka作为消息队列接收TTS任务,
使用Redis存储任务状态和结果,并利用GPT-SoVITS模型进行语音合成。
主要功能包括:
1. 初始化配置和模型
2. 提供语音合成功能
3. 监听Kafka消息并处理TTS任务
4. 将合成结果存储到Redis并更新任务状态
"""
# 加载环境变量
load_dotenv()
# 设置GPU设备(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 从环境变量中读取Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB')) # DB 2用于存储TTS结果
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB')) # DB 3用于存储任务状态
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 从环境变量中读取Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
# 从环境变量中读取TTS相关配置
GPT_MODEL_PATH = os.getenv('GPT_MODEL_PATH')
SOVITS_MODEL_PATH = os.getenv('SOVITS_MODEL_PATH')
REF_AUDIO_PATH = os.getenv('REF_AUDIO_ZN_PATH')
REF_TEXT_PATH = os.getenv('REF_TEXT_ZN_PATH')
REF_LANGUAGE = os.getenv('REF_LANGUAGE')
TARGET_LANGUAGE = os.getenv('TARGET_LANGUAGE')
OUTPUT_PATH = os.getenv('OUTPUT_PATH')
# 初始化Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TTS_DB,
password=REDIS_PASSWORD
)
redis_task_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_TASK_DB,
password=REDIS_PASSWORD
)
# 初始化国际化工具
i18n = I18nAuto()
def get_audio_hash(text):
    """Return the MD5 hex digest of *text* (used in audio file names)."""
    md5 = hashlib.md5()
    md5.update(text.encode())
    return md5.hexdigest()
# 初始化模型
print("正在初始化模型...")
change_gpt_weights(gpt_path=GPT_MODEL_PATH)
change_sovits_weights(sovits_path=SOVITS_MODEL_PATH)
# 读取参考文本
with open(REF_TEXT_PATH, 'r', encoding='utf-8') as file:
ref_text = file.read()
print("模型初始化成功。")
def synthesize(target_text, output_wav_path):
    """Synthesize speech for *target_text* with GPT-SoVITS.

    Args:
        target_text: Text to turn into speech.
        output_wav_path: Destination path of the output WAV file.

    Returns:
        The output path on success, or None when the model yielded no audio.
    """
    with torch.cuda.device(device):
        synthesis_result = get_tts_wav(ref_wav_path=REF_AUDIO_PATH,
                                       prompt_text=ref_text,
                                       prompt_language=i18n(REF_LANGUAGE),
                                       text=target_text,
                                       text_language=i18n(TARGET_LANGUAGE), top_p=1, temperature=1)
        # get_tts_wav yields chunks; the last yield holds the full result.
        result_list = list(synthesis_result)
        if result_list:
            last_sampling_rate, last_audio_data = result_list[-1]
            sf.write(output_wav_path, last_audio_data, last_sampling_rate)
            return output_wav_path
        else:
            return None
def kafka_consumer():
    """Consume TTS tasks from Kafka and process them until interrupted.

    For each message: mark the task "processing", synthesize the audio (or
    reuse an existing file with the same text hash), store the result path
    in the TTS result DB, and finally mark the task "completed" / "failed".
    """
    consumer = KafkaConsumer(
        KAFKA_TTS_TOPIC,
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        value_deserializer=lambda m: json.loads(m.decode('utf-8'))
    )
    print(f"TTS消费者已启动")
    for message in consumer:
        # BUG FIX: task_id is bound before the try body so the except
        # handler can still update the task status.  Previously a payload
        # missing 'task_id' made the handler itself raise NameError.
        task_id = message.value.get('task_id')
        try:
            if task_id is None:
                raise KeyError('task_id')
            target_text = message.value['text']
            text_hash = message.value['text_hash']
            # Mark the task as "processing".
            redis_task_client.set(f"task_status:tts:{task_id}", "processing")
            output_wav_path = os.path.join(OUTPUT_PATH, f"{text_hash}.wav")
            # Re-check existence in case another worker created it meanwhile.
            if not os.path.exists(output_wav_path):
                output_path = synthesize(target_text, output_wav_path)
            else:
                output_path = output_wav_path
            if output_path:
                # Store the result path in the TTS result DB.
                redis_client.set(f"tts:{task_id}", json.dumps({"path": output_path}))
                print(f"音频合成成功: {output_path}")
                redis_task_client.set(f"task_status:tts:{task_id}", "completed")
            else:
                print("音频合成失败")
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
        except Exception as e:
            print(f"处理消息时出错: {str(e)}")
            # Only record the failure if we know which task it was.
            if task_id is not None:
                redis_task_client.set(f"task_status:tts:{task_id}", "failed")
if __name__ == "__main__":
    # Pin the CUDA device for this process.
    torch.cuda.set_device(device)
    # Run the Kafka consumer loop (blocks forever).
    kafka_consumer()
-68
View File
@@ -1,68 +0,0 @@
import os
import whisper
import argparse
def transcribe_audio(model, audio_path):
    """Transcribe one audio file with a loaded Whisper model.

    Args:
        model: A loaded Whisper model exposing .transcribe().
        audio_path: Path of the audio file to transcribe.

    Returns:
        The transcribed text, or None when transcription raised.
    """
    try:
        return model.transcribe(audio_path)["text"]
    except Exception as e:
        print(f"转录失败 {audio_path}: {str(e)}")
        return None
def process_directory(directory, model):
    """Transcribe every .wav file in *directory*, writing a sibling .txt per file."""
    wav_names = [n for n in os.listdir(directory) if n.lower().endswith('.wav')]
    for name in wav_names:
        input_file = os.path.join(directory, name)
        output_file = os.path.splitext(input_file)[0] + ".txt"
        print(f"正在处理: {input_file}")
        transcription = transcribe_audio(model, input_file)
        if not transcription:
            print(f"转录失败: {input_file}")
            continue
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(transcription)
        print(f"转录完成: {output_file}")
def main():
    """CLI entry: transcribe a single WAV file or every WAV in a directory."""
    parser = argparse.ArgumentParser(description="使用Whisper将WAV文件转换为文本")
    parser.add_argument("input_path", help="输入的WAV文件或包含WAV文件的目录路径")
    parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large", "large-v3"], help="Whisper模型大小")
    args = parser.parse_args()

    print(f"正在加载Whisper模型 ({args.model})...")
    whisper_model = whisper.load_model(args.model)
    print("模型加载完成")

    # Guard-clause dispatch: directory, then single file, then error.
    if os.path.isdir(args.input_path):
        process_directory(args.input_path, whisper_model)
        return
    if not os.path.isfile(args.input_path):
        print("错误: 输入路径既不是文件也不是目录")
        return
    if not args.input_path.lower().endswith('.wav'):
        print("错误: 输入文件不是WAV格式")
        return
    output_file = os.path.splitext(args.input_path)[0] + ".txt"
    transcription = transcribe_audio(whisper_model, args.input_path)
    if transcription:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(transcription)
        print(f"转录完成: {output_file}")
# Script entry point.
if __name__ == "__main__":
    main()
-1
View File
@@ -1 +0,0 @@
{"GPT": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}
-186
View File
@@ -1,186 +0,0 @@
from fastapi import FastAPI, File, UploadFile, HTTPException, WebSocket
from fastapi.middleware.cors import CORSMiddleware
import whisper
import tempfile
import uvicorn
from kafka import KafkaProducer, KafkaConsumer
from starlette.websockets import WebSocketDisconnect
import json
import threading
import os
import uuid
import asyncio
import logging
import redis
from dotenv import load_dotenv
# 设置要使用的GPU ID
GPU_ID = 1 # 修改这个值来选择要使用的GPU
# 设置CUDA_VISIBLE_DEVICES环境变量
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
load_dotenv()
# CORS 配置
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
print("正在加载Whisper模型...")
model = whisper.load_model("large-v3")
print("Whisper模型加载完成。")
# Kafka配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
# Redis配置
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB,
password=REDIS_PASSWORD # 添加密码
)
# Kafka生产者
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 存储WebSocket连接的字典
active_connections = {}
@app.websocket("/asr/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Keep a per-client WebSocket open for pushing transcription results.

    Registers the socket in `active_connections` keyed by client_id,
    answers "ping" with "pong", echoes other messages, and sends a
    "heartbeat" after 30s of silence to detect dead peers.  The entry is
    removed on any disconnect or error.
    """
    await websocket.accept()
    active_connections[client_id] = websocket
    try:
        while True:
            try:
                # Receive with a timeout so idle connections get a heartbeat.
                data = await asyncio.wait_for(websocket.receive_text(), timeout=30)
                if data == "ping":
                    await websocket.send_text("pong")
                else:
                    await websocket.send_text(f"收到消息: {data}")
            except asyncio.TimeoutError:
                try:
                    # Heartbeat; a failed send means the client went away.
                    await websocket.send_text("heartbeat")
                except WebSocketDisconnect:
                    logger.info(f"客户端 {client_id} 断开连接")
                    break
    except WebSocketDisconnect:
        logger.info(f"客户端 {client_id} 断开连接")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        if client_id in active_connections:
            del active_connections[client_id]
@app.post("/asr")
async def transcribe(audio: UploadFile = File(...)):
    """Accept an uploaded audio file and queue it for Whisper transcription.

    Results are cached in Redis keyed by the MD5 of the audio content, so
    re-uploads of identical audio are answered immediately.  Otherwise the
    content is written to a temp file and a task is published to Kafka;
    the transcription is later pushed over the client's WebSocket.
    """
    import hashlib  # local import: content-addressed cache key

    if not audio:
        raise HTTPException(status_code=400, detail="未提供音频文件")
    client_id = str(uuid.uuid4())
    content = await audio.read()

    # BUG FIX: the cache key used to include the per-request client_id
    # (a fresh uuid4), so no request could ever hit the cache.  Key on
    # the audio content instead so identical uploads are cache hits.
    cache_key = f"asr:{hashlib.md5(content).hexdigest()}"

    cached_result = redis_client.get(cache_key)
    if cached_result:
        logger.info(f"缓存命中: {cache_key}")
        return {"message": "从缓存获取转录结果", "transcription": cached_result.decode('utf-8'), "client_id": client_id}

    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
        temp_audio.write(content)
        temp_audio.flush()
        task = {
            'file_path': temp_audio.name,
            'client_id': client_id,
            'cache_key': cache_key
        }
    producer.send(KAFKA_TOPIC, value=task)
    producer.flush()
    logger.info(f"发送任务到Kafka: {task}")
    return {"message": "音频文件已接收并发送任务进行处理", "client_id": client_id}
async def send_transcription(client_id: str, transcription: str):
    """Push a transcription to the client's open WebSocket, if any."""
    if client_id in active_connections:
        websocket = active_connections[client_id]
        await websocket.send_json({"transcription": transcription})
    else:
        logger.warning(f"客户端 {client_id} 的WebSocket连接不存在")
def kafka_consumer(consumer_id):
    """Worker loop: consume ASR tasks, run Whisper, cache and push results.

    Each task names a temp audio file, the requesting client, and the Redis
    cache key.  The transcription is cached for one hour, pushed to the
    client's WebSocket, and the temp file is deleted.
    """
    consumer = KafkaConsumer(
        KAFKA_TOPIC,
        bootstrap_servers=[KAFKA_BROKER],
        value_deserializer=lambda x: json.loads(x.decode('utf-8')),
        group_id='asr_group',
        max_poll_interval_ms=300000
    )
    for message in consumer:
        try:
            task = message.value
            file_path = task.get('file_path')
            client_id = task.get('client_id')
            cache_key = task.get('cache_key')
            if not file_path or not client_id or not cache_key:
                logger.error(f"消费者 {consumer_id} 收到无效任务: {task}")
                consumer.commit()
                continue
            result = model.transcribe(file_path)
            logger.info(f"消费者 {consumer_id} 处理了文件: {file_path}")
            logger.info(f"转录结果: {result['text']}")
            # Cache for 1 hour so identical keys skip re-running Whisper.
            redis_client.setex(cache_key, 3600, result['text'])
            # NOTE(review): asyncio.run spins up a fresh event loop per
            # message inside this worker thread while the WebSocket lives
            # on the server's loop — confirm delivery works under load.
            asyncio.run(send_transcription(client_id, result['text']))
            os.remove(file_path)
            consumer.commit()
        except Exception as e:
            logger.error(f"消费者 {consumer_id} 处理消息时发生错误: {str(e)}")
def start_consumers(num_consumers=1):
    """Spawn *num_consumers* background threads running kafka_consumer."""
    for idx in range(num_consumers):
        threading.Thread(target=kafka_consumer, args=(idx,)).start()
# Entry point: start the ASR consumer thread(s), then serve HTTP on port 6000.
if __name__ == '__main__':
    start_consumers()
    uvicorn.run(app, host="0.0.0.0", port=6000)
+30 -2
View File
@@ -23,6 +23,10 @@ REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
OLLAMA_URL = os.getenv('OLLAMA_URL')
OLLAMA_URLS = os.getenv('OLLAMA_URLS', OLLAMA_URL).split(',') # 兼容旧配置
OLLAMA_TIMEOUT = int(os.getenv('OLLAMA_TIMEOUT', 10))
# 创建Redis客户端
redis_client = redis.Redis(
host=REDIS_HOST,
@@ -51,6 +55,21 @@ def create_kafka_consumer():
value_deserializer=lambda x: json.loads(x.decode('utf-8'))
)
async def try_ollama_request(url, data):
    """Try one Ollama API endpoint; return the streaming response or None.

    NOTE(review): requests.post is a blocking call inside an async def —
    it stalls the event loop for up to OLLAMA_TIMEOUT seconds per probe;
    consider loop.run_in_executor or an async HTTP client.
    """
    try:
        response = requests.post(
            f"{url}/api/generate",
            json=data,
            stream=True,
            timeout=OLLAMA_TIMEOUT
        )
        response.raise_for_status()
        return response
    except Exception as e:
        print(f"API {url} 请求失败: {str(e)}")
        return None
async def process_chat_request(chat_request):
try:
task_id = chat_request['task_id']
@@ -77,9 +96,17 @@ async def process_chat_request(chat_request):
"temperature": 0
}
response = requests.post("https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate", json=data, stream=True)
response.raise_for_status()
# 尝试所有可用的 API 地址
response = None
for url in OLLAMA_URLS:
response = await try_ollama_request(url, data)
if response is not None:
print(f"使用 API 地址: {url}")
break
if response is None:
raise Exception("所有 API 地址均不可用")
text_output = ""
for line in response.iter_lines():
if line:
@@ -105,6 +132,7 @@ async def process_chat_request(chat_request):
# 设置任务状态为 "error"
redis_task_client.set(f"chat:{task_id}:status", "error")
redis_task_client.set(f"chat:{task_id}:error", str(e))
def kafka_consumer_thread(consumer_id):
consumer = create_kafka_consumer()
print(f"消费者 {consumer_id} 已启动")
Regular → Executable
View File
-94
View File
@@ -1,94 +0,0 @@
# Kafka 配置
KAFKA_BROKER=222.186.10.253:9092
KAFKA_ASR_TOPIC=asr
KAFKA_CHAT_TOPIC=chat
KAFKA_TTS_TOPIC=tts
# Redis 配置
REDIS_HOST=150.158.144.159
REDIS_PORT=13003
REDIS_ASR_DB=12
REDIS_CHAT_DB=13
REDIS_TTS_DB=14
REDIS_PASSWORD=Obscura@2024
REDIS_API_DB=2
REDIS_API_USAGE_DB=3
REDIS_TASK_DB=11
REDIS_SESSION_DB=63
REDIS_SESSION_DB_ZH=63
REDIS_SESSION_DB_EN=62
REDIS_SESSION_DB_KO=61
# CORS 配置
# ALLOWED_ORIGINS=https://beta.obscura.work
# GPT-SoVITS 配置
GPT_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
SOVITS_MODEL_PATH=GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
REF_AUDIO_PATH=sample/woman.wav
REF_TEXT_PATH=sample/woman.txt
REF_LANGUAGE=中文
TARGET_LANGUAGE=多语种混合
OUTPUT_PATH=/obscura/task/audio_files
# VOICE_CONFIGS
GIRL_REF_AUDIO=sample/gril.wav
GIRL_REF_TEXT=sample/gril.txt
WOMAN_REF_AUDIO=sample/woman.wav
WOMAN_REF_TEXT=sample/woman.txt
MAN_REF_AUDIO=sample/man.wav
MAN_REF_TEXT=sample/man.txt
LEIJUN_REF_AUDIO=sample/leijun.wav
LEIJUN_REF_TEXT=sample/leijun.txt
DUFU_REF_AUDIO=sample/dufu.wav
DUFU_REF_TEXT=sample/dufu.txt
HEJIONG_REF_AUDIO=sample/hejiong.wav
HEJIONG_REF_TEXT=sample/hejiong.txt
MAHUATENG_REF_AUDIO=sample/mahuateng.wav
MAHUATENG_REF_TEXT=sample/mahuateng.txt
LIDAN_REF_AUDIO=sample/lidan.wav
LIDAN_REF_TEXT=sample/lidan.txt
YUHUA_REF_AUDIO=sample/yuhua.wav
YUHUA_REF_TEXT=sample/yuhua.txt
LIUZHENYUN_REF_AUDIO=sample/liuzhenyun.wav
LIUZHENYUN_REF_TEXT=sample/liuzhenyun.txt
DABING_REF_AUDIO=sample/dabing.wav
DABING_REF_TEXT=sample/dabing.txt
LUOXIANG_REF_AUDIO=sample/luoxiang.wav
LUOXIANG_REF_TEXT=sample/luoxiang.txt
XUZHIYUAN_REF_AUDIO=sample/xuzhiyuan.wav
XUZHIYUAN_REF_TEXT=sample/xuzhiyuan.txt
REDIS_GIRL_DB = 15
REDIS_WOMAN_DB = 16
REDIS_MAN_DB = 17
REDIS_LEIJUN_DB = 18
REDIS_DUFU_DB = 19
REDIS_HEJIONG_DB = 20
REDIS_MAHUATENG_DB = 21
REDIS_LIDAN_DB = 22
REDIS_DABING_DB = 23
REDIS_LUOXIANG_DB = 24
REDIS_XUZHIYUAN_DB = 25
REDIS_YUHUA_DB = 26
REDIS_LIUZHENYUN_DB = 27
-406
View File
@@ -1,406 +0,0 @@
from fastapi import FastAPI, HTTPException, Depends, Security, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader
from fastapi.responses import FileResponse
from pydantic import BaseModel
from kafka import KafkaProducer
from redis import Redis
import os
import json
import uuid
from datetime import datetime, timezone
from dotenv import load_dotenv
import tempfile
import hashlib
from pydantic import BaseModel, Field
# Shared helper: content hash used as cache key / audio file name.
def get_audio_hash(text):
    """Return the MD5 hex digest of *text*."""
    hashed = hashlib.md5(text.encode('utf-8'))
    return hashed.hexdigest()
# 加载 .env 文件
load_dotenv()
app = FastAPI()
v1_chat_app = FastAPI()
app.mount("/v1_chat", v1_chat_app)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 配置
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TTS_DB = int(os.getenv('REDIS_TTS_DB'))
REDIS_ASR_DB = int(os.getenv('REDIS_ASR_DB'))
REDIS_CHAT_DB = int(os.getenv('REDIS_CHAT_DB'))
REDIS_API_DB = int(os.getenv('REDIS_API_DB'))
REDIS_API_USAGE_DB = int(os.getenv('REDIS_API_USAGE_DB'))
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
# Redis 配置
REDIS_GIRL_DB = int(os.getenv('REDIS_GIRL_DB'))
REDIS_WOMAN_DB = int(os.getenv('REDIS_WOMAN_DB'))
REDIS_MAN_DB = int(os.getenv('REDIS_MAN_DB'))
REDIS_LEIJUN_DB = int(os.getenv('REDIS_LEIJUN_DB'))
REDIS_DUFU_DB = int(os.getenv('REDIS_DUFU_DB'))
REDIS_HEJIONG_DB = int(os.getenv('REDIS_HEJIONG_DB'))
REDIS_MAHUATENG_DB = int(os.getenv('REDIS_MAHUATENG_DB'))
REDIS_LIDAN_DB = int(os.getenv('REDIS_LIDAN_DB'))
REDIS_YUHUA_DB = int(os.getenv('REDIS_YUHUA_DB'))
REDIS_LIUZHENYUN_DB = int(os.getenv('REDIS_LIUZHENYUN_DB'))
REDIS_DABING_DB = int(os.getenv('REDIS_DABING_DB'))
REDIS_LUOXIANG_DB = int(os.getenv('REDIS_LUOXIANG_DB'))
REDIS_XUZHIYUAN_DB = int(os.getenv('REDIS_XUZHIYUAN_DB'))
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')
KAFKA_ASR_TOPIC = os.getenv('KAFKA_ASR_TOPIC')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
OUTPUT_PATH= os.getenv('OUTPUT_PATH')
# 初始化 Kafka Producer
producer = KafkaProducer(
bootstrap_servers=[KAFKA_BROKER],
value_serializer=lambda v: json.dumps(v).encode('utf-8')
)
# 初始化 Redis
redis_tts_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TTS_DB)
redis_asr_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_ASR_DB)
redis_chat_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_CHAT_DB)
redis_api_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_DB)
redis_api_usage_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_API_USAGE_DB)
redis_task_client = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_TASK_DB)
redis_tts_girl = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_GIRL_DB)
redis_tts_woman = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_WOMAN_DB)
redis_tts_man = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAN_DB)
redis_tts_leijun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LEIJUN_DB)
redis_tts_dufu = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DUFU_DB)
redis_tts_hejiong = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_HEJIONG_DB)
redis_tts_mahuateng = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_MAHUATENG_DB)
redis_tts_lidan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIDAN_DB)
redis_tts_yuhua = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_YUHUA_DB)
redis_tts_liuzhenyun = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LIUZHENYUN_DB)
redis_tts_dabing = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_DABING_DB)
redis_tts_luoxiang = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_LUOXIANG_DB)
redis_tts_xuzhiyuan = Redis(host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD, db=REDIS_XUZHIYUAN_DB)
# Map each voice name to the Redis client holding that voice's TTS cache.
# "default" shares the girl DB, matching the default-voice aliasing in /tts.
voice_to_redis = {
    "default": redis_tts_girl,
    "girl": redis_tts_girl,
    "woman": redis_tts_woman,
    "man": redis_tts_man,
    "leijun": redis_tts_leijun,
    "dufu": redis_tts_dufu,
    "hejiong": redis_tts_hejiong,
    "mahuateng": redis_tts_mahuateng,
    "lidan": redis_tts_lidan,
    "yuhua": redis_tts_yuhua,
    "liuzhenyun": redis_tts_liuzhenyun,
    "dabing": redis_tts_dabing,
    "luoxiang": redis_tts_luoxiang,
    "xuzhiyuan": redis_tts_xuzhiyuan
}
# The API key arrives in the Authorization header ("Bearer obs-...").
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
def get_audio_hash(text):
    """MD5 hex digest of *text* (duplicate of the top-of-file helper)."""
    return hashlib.md5(bytes(text, 'utf-8')).hexdigest()
# Authentication step 1: pull the bearer token out of the Authorization header.
async def get_api_key(api_key: str = Security(api_key_header)):
    """Return the bare key when the header looks like "Bearer obs-..."; else 401."""
    if api_key and api_key.startswith("Bearer "):
        token = api_key.split(" ")[1]
        if token.startswith("obs-"):
            return token
    raise HTTPException(
        status_code=401,
        detail="无效的API密钥",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def verify_api_key(api_key: str = Depends(get_api_key)):
    """Validate the key against Redis: must exist, be active, and not expired.

    Returns:
        A merged dict of the key's metadata plus its usage counters.

    Raises:
        HTTPException 401 when the key is unknown, deactivated, or expired.
    """
    redis_key = f"api_key:{api_key}"
    api_key_info = redis_api_client.hgetall(redis_key)
    if not api_key_info:
        raise HTTPException(status_code=401, detail="无效的API密钥")
    api_key_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in api_key_info.items()}
    if api_key_info.get('is_active') != '1':
        raise HTTPException(status_code=401, detail="API密钥已停用")
    # NOTE(review): this comparison assumes expires_at was stored as an
    # offset-aware ISO timestamp; a naive one would raise TypeError here —
    # confirm the writer always stores aware timestamps.
    expires_at = datetime.fromisoformat(api_key_info.get('expires_at'))
    if datetime.now(timezone.utc) > expires_at:
        raise HTTPException(status_code=401, detail="API密钥已过期")
    usage_info = redis_api_usage_client.hgetall(redis_key)
    usage_info = {k.decode('utf-8'): v.decode('utf-8') for k, v in usage_info.items()}
    return {
        "api_key": api_key,
        **api_key_info,
        **usage_info
    }
async def update_token_usage(api_key: str, new_tokens_used: int, model_name: str):
    """Atomically bump the global and per-model token counters for *api_key*."""
    redis_key = f"api_key:{api_key}"
    now_iso = datetime.now(timezone.utc).isoformat()
    per_model_tokens = f"{model_name}_tokens_used"
    per_model_last_used = f"{model_name}_last_used_at"
    # A pipeline applies the four writes in one round trip.
    pipe = redis_api_usage_client.pipeline()
    pipe.hincrby(redis_key, "tokens_used", new_tokens_used)
    pipe.hset(redis_key, "last_used_at", now_iso)
    pipe.hincrby(redis_key, per_model_tokens, new_tokens_used)
    pipe.hset(redis_key, per_model_last_used, now_iso)
    pipe.execute()
async def process_request(api_key_info: dict, model_name: str, tokens_required: int, task_data: dict, kafka_topic: str):
    """Charge tokens for a request and enqueue it on Kafka.

    Rejects with 403 when the key's balance cannot cover *tokens_required*;
    otherwise records the usage, publishes *task_data* to *kafka_topic*, and
    returns a summary of the key's updated token accounting.
    """
    api_key = api_key_info['api_key']
    usage_key = f"api_key:{api_key}"
    total_tokens = int(redis_api_usage_client.hget(usage_key, "total_tokens") or 0)
    tokens_used = int(redis_api_usage_client.hget(usage_key, "tokens_used") or 0)
    if tokens_used + tokens_required > total_tokens:
        raise HTTPException(status_code=403, detail="Token 余额不足")
    # Charge first, then enqueue: the task is only published after a successful charge.
    await update_token_usage(api_key, tokens_required, model_name)
    producer.send(kafka_topic, task_data)
    # Re-read counters so the response reflects the charge just made.
    updated_api_key_info = await verify_api_key(api_key)
    new_tokens_used = int(updated_api_key_info.get("tokens_used", 0))
    model_tokens_used = int(updated_api_key_info.get(f"{model_name}_tokens_used", 0))
    return {
        "message": f"{model_name.upper()}请求已排队等待处理",
        "tokens_used": tokens_required,
        "total_tokens_used": new_tokens_used,
        f"{model_name}_tokens_used": model_tokens_used,
        "tokens_remaining": total_tokens - new_tokens_used
    }
class TTSRequest(BaseModel):
    """Body of /tts: the text to synthesize and the requested voice."""
    text: str
    voice: str = Field(..., description="选择的音色")
class ChatRequest(BaseModel):
    """Body of /chat: session id, user query, and model name."""
    session_id: str
    query: str
    model: str = "qwen2.5:3b"
@v1_chat_app.post("/tts")
async def tts_request(request: TTSRequest, api_key_info: dict = Depends(verify_api_key)):
    """Queue a TTS job, or return an already-synthesized audio path.

    Validates the requested voice, checks the voice-specific Redis cache
    for an existing audio file for this exact text, and otherwise charges
    one token and publishes the task to the TTS Kafka topic.
    """
    task_id = str(uuid.uuid4())
    text_hash = get_audio_hash(request.text)

    valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong", "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
    if request.voice not in valid_voices:
        raise HTTPException(status_code=400, detail="无效的音色选择")

    # "default" is an alias for "girl": voice_to_redis maps both to the same DB.
    # (The previous code computed a `voice = 'girl' if ...` local and never
    # used it — removed as dead code; the mapping already handles the alias.)
    redis_tts = voice_to_redis[request.voice]

    # Serve from cache when this text was already synthesized for this voice.
    existing_audio_info = redis_tts.get(f"tts:{text_hash}")
    if existing_audio_info:
        existing_audio_path = json.loads(existing_audio_info)['path']
        if os.path.exists(existing_audio_path):
            return {
                "message": "TTS请求已完成",
                "task_id": task_id,
                "status": "completed",
                "audio_path": existing_audio_path
            }

    # Cache miss: create and enqueue a new task.
    task_data = {
        "task_id": task_id,
        "text": request.text,
        "text_hash": text_hash,
        "voice": request.voice,
        "status": "queued",
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    redis_task_client.set(f"task_status:tts:{task_id}", "queued")
    result = await process_request(api_key_info, "tts", 1, task_data, KAFKA_TTS_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/asr")
async def asr_request(audio: UploadFile = File(...), api_key_info: dict = Depends(verify_api_key)):
    """Persist the uploaded audio to disk and queue an ASR task."""
    task_id = str(uuid.uuid4())
    upload_dir = "/obscura/task/audio_upload"
    os.makedirs(upload_dir, exist_ok=True)
    # Spool the upload to a per-task WAV file that the worker can read.
    file_path = os.path.join(upload_dir, f"{task_id}.wav")
    with open(file_path, "wb") as out_file:
        out_file.write(await audio.read())
    task_data = {
        'file_path': file_path,
        'task_id': task_id,
        'status': 'queued'
    }
    # Mark the task as queued using the shared "task_status:<kind>:<id>" scheme.
    redis_task_client.set(f"task_status:asr:{task_id}", "queued")
    result = await process_request(api_key_info, "asr", 1, task_data, KAFKA_ASR_TOPIC)
    result["task_id"] = task_id
    return result
@v1_chat_app.post("/chat")
async def chat_request(request: ChatRequest, api_key_info: dict = Depends(verify_api_key)):
task_id = str(uuid.uuid4())
task_data = {
"task_id": task_id,
"session_id": request.session_id,
"query": request.query,
"model": request.model,
"status": "queued",
"created_at": datetime.now(timezone.utc).isoformat(),
}
# 设置任务状态为 "queued"
redis_task_client.set(f"chat:{task_id}:status", "queued")
result = await process_request(api_key_info, "chat", 1, task_data, KAFKA_CHAT_TOPIC)
result["task_id"] = task_id
return result
@v1_chat_app.get("/chat_result/{task_id}")
async def get_chat_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a chat task's status; include the decoded result once completed."""
    raw_status = redis_task_client.get(f"chat:{task_id}:status")
    if not raw_status:
        # No status key means the task was never queued (or has expired).
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    if status == "completed":
        payload = redis_task_client.get(f"chat:{task_id}:result")
        if payload:
            return {
                "status": "completed",
                "result": json.loads(payload)
            }
    # Pending/processing/failed — or completed with the result missing.
    return {"status": status}
@v1_chat_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll a TTS task; return the audio path once synthesis has finished."""
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not raw_status:
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    if status == "completed":
        task_info = redis_task_client.get(f"task_info:tts:{task_id}")
        if task_info:
            task_data = json.loads(task_info)
            voice = task_data['voice']
            # 'default' shares the 'girl' audio cache.
            store = voice_to_redis['girl'] if voice in ('default', 'girl') else voice_to_redis[voice]
            audio_info = store.get(f"tts:{task_data['text_hash']}")
            if audio_info:
                return {
                    "status": "completed",
                    "audio_path": json.loads(audio_info)['path']
                }
    # Pending/failed — or completed but the cache entry is gone.
    return {"status": status}
@v1_chat_app.get("/asr_result/{task_id}")
async def get_asr_result(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Poll an ASR task; return the transcription once recognition finished."""
    raw_status = redis_task_client.get(f"task_status:asr:{task_id}")
    if not raw_status:
        return {"status": "not_found"}
    status = raw_status.decode('utf-8')
    if status == "completed":
        transcription = redis_asr_client.get(f"asr:{task_id}")
        if transcription:
            return {
                "status": "completed",
                "transcription": transcription.decode('utf-8')
            }
    # Pending/failed — or completed but the result key is missing.
    return {"status": status}
@v1_chat_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str, api_key_info: dict = Depends(verify_api_key)):
    """Stream the generated WAV for a finished TTS task.

    Raises:
        HTTPException(202): task still queued or processing.
        HTTPException(500): worker reported a failure status.
        HTTPException(404): unknown task, missing cache entry, or the audio
            file is gone from disk.
    """
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if raw_status:
        status = raw_status.decode('utf-8')
        if status in ("queued", "processing"):
            raise HTTPException(status_code=202, detail="音频文件正在生成中")
        if status != "completed":
            raise HTTPException(status_code=500, detail="任务处理失败")
        task_info = redis_task_client.get(f"task_info:tts:{task_id}")
        if task_info:
            task_data = json.loads(task_info)
            voice = task_data.get('voice', 'girl')  # older tasks may lack the field
            # 'default' shares the 'girl' audio cache.
            store = voice_to_redis['girl'] if voice in ('default', 'girl') else voice_to_redis[voice]
            audio_info = store.get(f"tts:{task_data['text_hash']}")
            if audio_info:
                audio_path = json.loads(audio_info)['path']
                if os.path.exists(audio_path):
                    return FileResponse(audio_path, media_type="audio/wav", filename=os.path.basename(audio_path))
                raise HTTPException(status_code=404, detail="音频文件不存在")
        # Completed but the task/cache record is missing: fall through to 404.
    raise HTTPException(status_code=404, detail="任务不存在")
@v1_chat_app.get("/getvoice")
async def get_available_voices(api_key_info: dict = Depends(verify_api_key)):
    """List every voice name accepted by the /tts endpoint."""
    voices = [
        "default", "girl", "woman", "man",
        "leijun", "dufu", "hejiong", "mahuateng",
        "lidan", "yuhua", "liuzhenyun", "dabing",
        "luoxiang", "xuzhiyuan",
    ]
    return {"available_voices": voices}
if __name__ == "__main__":
    # Local entry point: serve the FastAPI app on all interfaces, port 8008.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8008)
-7
View File
@@ -1,7 +0,0 @@
fastapi
uvicorn
pydantic
kafka-python
redis
python-dotenv
python-multipart
Regular → Executable
View File
Regular → Executable
View File
+1
View File
@@ -0,0 +1 @@
{"GPT": {"v1": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt"}, "SoVITS": {"v1": "GPT_SoVITS/pretrained_models/s2G488k.pth", "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth"}}