775 lines
32 KiB
Python
775 lines
32 KiB
Python
import asyncio
import json
import os
import time
import uuid
from datetime import datetime, timezone
from typing import List
from urllib.parse import quote_plus

import httpx
import redis
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from kafka import KafkaProducer
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from pydub import AudioSegment
from sqlalchemy import create_engine, Column, String, DateTime, Boolean
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from sqlalchemy.sql import func
|
||
|
||
|
||
# Load variables from a .env file into the process environment before any
# os.getenv() lookup below runs.
load_dotenv()

app = FastAPI()
# Every API route in this module is declared on user_app, served under /user.
user_app = FastAPI()
app.mount("/user", user_app)

# Cross-origin requests. Allowed origins come from the comma-separated
# CORS_ALLOW_ORIGINS variable; when it is unset (or empty) the previous
# allow-everything "*" setting is kept for backward compatibility.
# NOTE(review): "*" together with allow_credentials=True can let any web page
# make credentialed requests; set CORS_ALLOW_ORIGINS in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
def _int_env(name):
    """Return environment variable *name* converted to int.

    Raises TypeError when the variable is unset (int(None)) and ValueError
    when it is not numeric.
    """
    return int(os.getenv(name))


# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')

# Redis connection settings; the port and every logical-DB index are required.
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _int_env('REDIS_PORT')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = _int_env('REDIS_TASK_DB')
REDIS_SESSION_DB_ZH = _int_env('REDIS_SESSION_DB_ZH')
REDIS_SESSION_DB_EN = _int_env('REDIS_SESSION_DB_EN')
REDIS_SESSION_DB_KO = _int_env('REDIS_SESSION_DB_KO')
REDIS_REGISTER_DB = _int_env('REDIS_REGISTER_DB')
REDIS_DB = _int_env('REDIS_DB')
|
||
|
||
def _make_redis(db):
    """Build a client for logical database *db* on the shared Redis server."""
    return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=db, password=REDIS_PASSWORD)


# TTS job bookkeeping: "task_status:tts:<id>" and "task_info:tts:<id>" keys.
redis_task_client = _make_redis(REDIS_TASK_DB)
# General store; login sessions are kept here as "session:<token>".
redis_client = _make_redis(REDIS_DB)
redis_register_client = _make_redis(REDIS_REGISTER_DB)

# Saved discussions, one logical DB per supported language.
redis_session_clients = {
    lang: _make_redis(db_index)
    for lang, db_index in (
        ('zh', REDIS_SESSION_DB_ZH),
        ('en', REDIS_SESSION_DB_EN),
        ('ko', REDIS_SESSION_DB_KO),
    )
}

# Generated audio records, one logical DB per TTS voice; the index for a
# voice comes from the REDIS_<VOICE>_DB environment variable.
voice_to_redis = {
    voice: _make_redis(int(os.getenv(f"REDIS_{voice.upper()}_DB")))
    for voice in (
        'girl', 'woman', 'man', 'leijun', 'dufu', 'hejiong',
        'mahuateng', 'lidan', 'luoxiang', 'xuzhiyuan', 'dabing', 'yuhua',
    )
}
|
||
# On-disk locations of generated audio files and avatar images.
AUDIO_BASE_PATH = "/obscura/task/audio_files"
AVATAR_BASE_PATH = "/obscura/task/avatar"

# Root folder of the per-language team-member JSON profiles, read by
# load_team_members as <TEAM_MEMBERS_PATH>/<language>/<file stem>.json.
TEAM_MEMBERS_PATH = "team_members"

# Role name as sent by the client -> file stem of that role's JSON profile.
ROLE_TO_FILENAME = {
    # Generic expert roles
    "技术专家": "tech_expert",
    "创意专家": "creative",
    "数据分析专家": "dataanalyst",
    "项目规划专家": "pragmatist",
    "市场营销专家": "marketing_expert",
    "财务专家": "financial_expert",
    # Named personas
    "马化腾": "mahuateng",
    "李诞": "lidan",
    "罗翔": "luoxiang",
    "许知远": "xuzhiyuan",
    "大冰": "dabing",
    "余华": "yuhua",
    "雷军": "leijun",
}
|
||
|
||
# Host (moderator) persona. Unlike the other roles it is defined inline here
# instead of being read from a JSON profile; text fields are per language.
_HOST_INFO = {
    'name': {
        'zh': '何主持',
        'en': 'Host He',
        'ko': '사회자 허'
    },
    'post': {
        'zh': '主持人',
        'en': 'Host',
        'ko': '사회자'
    },
    'voice': 'hejiong',
    'avatar': 'cn__00138_.png',
    'personality': {
        'zh': (
            "经验丰富,专门负责引导六人跨部门小组的头脑风暴会议。你的任务是有效地主持讨论,基于我提出的具体问题展开对话。请记住,你的回应和引导必须简洁明了。\n"
            "作为主持人,你应该:\n"
            "用简短的话语开场,迅速将注意力集中到我提出的问题上,\n"
            "你的目标是通过简洁有力的引导,创造一个富有成效的讨论环境,让每个团队成员都能围绕我提出的问题贡献有价值的见解。"
        ),
        'en': (
            "Experienced in guiding brainstorming sessions for a six-person cross-departmental team. Your task is to effectively moderate the discussion based on the specific questions I raise. Remember, your responses and guidance must be concise and clear.\n"
            "As the host, you should:\n"
            "Open with brief remarks, quickly focusing attention on the question I've posed,\n"
            "Your goal is to create a productive discussion environment through concise and powerful guidance, allowing each team member to contribute valuable insights around the question I've raised."
        ),
        'ko': (
            "6인 부서 간 팀의 브레인스토밍 회의를 이끄는 데 경험이 풍부합니다. 당신의 임무는 내가 제기한 구체적인 질문 바탕으로 토론을 효과적으로 진행하는 것입니다. 당신의 응답과 안내는 간결하고 명확해야 함을 기억하세요.\n"
            "사회자로서 당신은:\n"
            "간단한 말로 시작하여 내가 제기한 질문에 빠르게 주의를 집중시켜야 합니다,\n"
            "당신의 목표는 간결하고 강력한 안내를 통해 생산적인 토론 환경을 만들어, 각 팀원이 내가 제기한 질문에 대해 가치 있는 통찰력을 기여할 수 있도록 하는 것입니다."
        ),
    },
}


def load_team_members(language='zh', selected_roles=None):
    """Build the role -> member-profile mapping used by a discussion.

    Args:
        language: Profile language ('zh', 'en' or 'ko'). It selects the
            TEAM_MEMBERS_PATH/<language> folder. The host's texts fall back
            to 'zh' for any other value instead of raising KeyError.
        selected_roles: Role names to include. ``None`` or an empty list both
            mean "every known role, including the host". Previously an empty
            list loaded all experts but silently dropped the host, and callers
            then failed looking up '主持人'.

    Returns:
        Dict keyed by role name, with the host ('主持人') first when
        included. Non-host roles are read from
        TEAM_MEMBERS_PATH/<language>/<ROLE_TO_FILENAME[role]>.json. Unknown
        roles and missing files are skipped silently.
    """
    load_all = not selected_roles
    team_members = {}

    if load_all or '主持人' in selected_roles:
        host_lang = language if language in _HOST_INFO['name'] else 'zh'
        team_members['主持人'] = {
            'name': _HOST_INFO['name'][host_lang],
            'post': _HOST_INFO['post'][host_lang],
            'voice': _HOST_INFO['voice'],
            'avatar': _HOST_INFO['avatar'],
            'personality': _HOST_INFO['personality'][host_lang],
        }

    lang_path = os.path.join(TEAM_MEMBERS_PATH, language)
    roles_to_load = list(ROLE_TO_FILENAME) if load_all else selected_roles

    for role in roles_to_load:
        # The host is built above; roles without a profile mapping are ignored.
        if role == '主持人' or role not in ROLE_TO_FILENAME:
            continue
        file_path = os.path.join(lang_path, f"{ROLE_TO_FILENAME[role]}.json")
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                team_members[role] = json.load(file)

    return team_members
|
||
|
||
def _serialize_json(value):
    """Encode a Kafka message payload as UTF-8 JSON bytes."""
    return json.dumps(value).encode('utf-8')


# Shared Kafka producer; TTS jobs are published to KAFKA_TTS_TOPIC.
producer = KafkaProducer(
    value_serializer=_serialize_json,
    bootstrap_servers=KAFKA_BROKER,
)
|
||
|
||
class DiscussionRequest(BaseModel):
    """Request body for /api/start-discussion."""

    topic: str
    max_rounds: int
    session_id: str
    selected_roles: List[str]
    # Discussion language ('zh', 'en' or 'ko'). Optional: start_discussion
    # already falls back to 'zh' for a missing/unsupported value, so making it
    # required only turned an omitted field into a 422 validation error.
    language: str = 'zh'
|
||
|
||
# SiliconFlow chat-completions API settings.
# NOTE(security): the API key used to be hard-coded in this file. It is now
# read from the environment (.env is loaded at import time). Any key that was
# previously committed should be treated as leaked and revoked.
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
SILICONFLOW_API_URL = os.getenv(
    "SILICONFLOW_API_URL", "https://api.siliconflow.cn/v1/chat/completions"
)
|
||
|
||
async def get_ai_response(prompt, model='deepseek-ai/DeepSeek-V2.5', language='zh', timeout=120.0):
    """Send *prompt* to the SiliconFlow chat-completions API and return the reply.

    The HTTP call uses httpx's async client so the event loop keeps serving
    other requests while the model runs. The previous blocking
    ``requests.post`` froze the whole server for the round trip and had no
    timeout.

    Args:
        prompt: User message content.
        model: SiliconFlow model identifier.
        language: 'zh', 'en' or 'ko'. Selects a system instruction asking for
            a reply in that language; other values fall back to 'zh'.
        timeout: Seconds allowed for the HTTP request.

    Returns:
        The model's reply text. If the reply parses as a JSON object, its
        'response' field is returned when present. A fixed fallback message
        is returned when the API yields no choices or no content.

    Raises:
        httpx.HTTPError: transport failure, timeout, or non-2xx status.
    """
    fallback = "无法生成回应。请重试。"
    language_prompts = {
        'zh': "请用中文回答以下问题:",
        'en': "Please answer the following question in English:",
        'ko': "다음 질문에 한국어로 답해주세요:",
    }
    lang_prompt = language_prompts.get(language, language_prompts['zh'])

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": lang_prompt},
            {"role": "user", "content": prompt},
        ],
        "stream": False,
        "max_tokens": 512,
        # NOTE(review): "<string>" looks like a placeholder copied from the API
        # reference; it only stops generation on that literal text.
        "stop": ["<string>"],
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "frequency_penalty": 0.5,
        "n": 1,
        # Plain-text output; the reply is optionally unwrapped from JSON below.
        "response_format": {"type": "text"},
    }
    headers = {
        "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(SILICONFLOW_API_URL, json=payload, headers=headers)
    response.raise_for_status()

    response_data = response.json()
    choices = response_data.get('choices') if isinstance(response_data, dict) else None
    if not choices:
        return fallback
    content = choices[0]['message']['content']
    if content is None:
        return fallback

    # Replies are sometimes JSON despite the text format; unwrap a
    # {"response": ...} object. Non-object JSON (e.g. "42" or a list) is
    # returned verbatim. The old code called .get() on it and crashed.
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        return content
    if isinstance(parsed, dict):
        return parsed.get('response', content)
    return content
|
||
|
||
async def wait_for_audio_file(task_id, timeout=30, base_path=None, poll_interval=0.5):
    """Poll a directory until a ``.wav`` file whose name starts with *task_id* appears.

    Args:
        task_id: Filename prefix to look for.
        timeout: Maximum number of seconds to wait.
        base_path: Directory to scan; defaults to AUDIO_BASE_PATH.
        poll_interval: Seconds to sleep between directory scans.

    Returns:
        The bare filename (not a full path) of the alphabetically first
        match, or None if nothing matched before *timeout* elapsed.
    """
    directory = AUDIO_BASE_PATH if base_path is None else base_path
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # Sorted so the result is deterministic when several files match.
        matches = sorted(
            name for name in os.listdir(directory)
            if name.startswith(task_id) and name.endswith('.wav')
        )
        if matches:
            return matches[0]
        await asyncio.sleep(poll_interval)
    return None
|
||
|
||
async def concatenate_audio_files(audio_files, output_path):
    """Join WAV files end-to-end and write the result to *output_path*.

    Decoding and encoding with pydub is blocking file I/O plus CPU work. It
    therefore runs in the default executor rather than on the event loop,
    where it stalled every other in-flight request, including open SSE
    streams. An empty *audio_files* list produces an empty WAV file.

    Args:
        audio_files: Paths of WAV files, in playback order.
        output_path: Destination path of the combined WAV.
    """
    def _concatenate():
        combined = AudioSegment.empty()
        for audio_file in audio_files:
            combined += AudioSegment.from_wav(audio_file)
        combined.export(output_path, format="wav")

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _concatenate)
|
||
|
||
def _submit_tts(text, voice):
    """Publish a TTS job for *text* in *voice* to Kafka and return its task id.

    The task id doubles as the job's text_hash, the key under which the
    generated audio record is later looked up (``tts:<text_hash>``).
    """
    tts_task_id = str(uuid.uuid4())
    producer.send(KAFKA_TTS_TOPIC, {
        'task_id': tts_task_id,
        'text': text,
        'text_hash': tts_task_id,
        'voice': voice,
    })
    return tts_task_id


async def _await_tts_audio(tts_task_id, timeout=None, poll_interval=0.5):
    """Poll Redis until TTS task *tts_task_id* has usable audio.

    Returns:
        True once the task status is "completed" and the voice's Redis DB
        holds ``tts:<text_hash>``. False when the status is "failed", when
        the task's voice has no Redis client, or when *timeout* seconds pass
        (``None`` waits indefinitely).
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        raw_status = redis_task_client.get(f"task_status:tts:{tts_task_id}")
        task_state = raw_status.decode('utf-8') if raw_status else None
        if task_state == "failed":
            return False
        if task_state == "completed":
            task_info = redis_task_client.get(f"task_info:tts:{tts_task_id}")
            if task_info:
                task_data = json.loads(task_info)
                redis_tts = voice_to_redis.get(task_data['voice'])
                if redis_tts is None:
                    # No per-voice DB, so the audio record can never be found.
                    return False
                if redis_tts.get(f"tts:{task_data['text_hash']}"):
                    return True
            # Completed but info/audio not visible yet: keep polling.
        if deadline is not None and time.monotonic() >= deadline:
            return False
        await asyncio.sleep(poll_interval)


async def team_discussion_generator(topic, max_rounds, selected_roles, language, tts_timeout=300):
    """Run a multi-round AI team discussion, yielding JSON strings as it progresses.

    Yields, in order:
      * ``{"topic": ..., "audio": None}`` once;
      * one message per member per round: ``{"post", "chunk",
        "audio_task_id", "avatar"}`` when its TTS audio is ready, or
        ``{"post", "chunk", "audio": None, "avatar"}`` when TTS failed or
        timed out;
      * the host's summary ``{"post", "chunk", "audio_task_id", "avatar"}``,
        where audio_task_id is None if its TTS failed or timed out;
      * ``{"chunk": "[DISCUSSION_COMPLETED]", "audio_task_id": <summary task id>}``.

    Args:
        topic: Question under discussion.
        max_rounds: Number of rounds; every loaded member speaks once per round.
        selected_roles: Role names passed to load_team_members().
        language: Discussion language used for prompts and member profiles.
        tts_timeout: Seconds to wait for each TTS job before emitting the
            message without audio. ``None`` restores the unbounded wait, which
            hung the stream forever when the TTS worker never answered.
    """
    team_members = load_team_members(language, selected_roles)

    discussion_topics = {
        'zh': "讨论主题",
        'en': "Discussion Topic",
        'ko': "토론 주제",
    }
    topic_header = discussion_topics.get(language, discussion_topics['zh'])

    # Running transcript, fed back into every prompt and into the summary.
    discussion = [f"{topic_header}: {topic}"]
    yield json.dumps({"topic": f"{topic_header}: {topic}\n", "audio": None})

    for _ in range(max_rounds):
        for info in team_members.values():
            prompt = (
                f"你是一个团队中的{info['post']},{info['personality']}。\n"
                f'团队正在讨论以下问题:"{topic}",请用{language}语言回答\n'
                "当前讨论进展:\n"
                f"{''.join(discussion)}\n"
                "\n"
                "请根据你的角色和特点,对这个问题发表你的看法或对其他成员的观点进行回应,在多轮讨论中,应该根据讨论历史不断优化回答。\n"
                "在回答时打招呼内容不要带有自己的角色和别人的角色,语气应避免单调和机械,尽可能口语化,回答内容保持简洁,控制在150字以内\n"
            )
            response = await get_ai_response(prompt, language=language)
            discussion.append(f"\n{info['name']}({info['post']}):{response}")

            tts_task_id = _submit_tts(response, info['voice'])
            if await _await_tts_audio(tts_task_id, tts_timeout):
                output = {
                    "post": f"{info['post']}",
                    "chunk": f"{info['name']}:{response}",
                    "audio_task_id": tts_task_id,
                    "avatar": info['avatar'],
                }
                print(json.dumps({
                    "type": "output",
                    "content": output,
                }, ensure_ascii=False, indent=2))
            else:
                output = {
                    "post": f"{info['post']}",
                    "chunk": f"{info['name']}:{response}",
                    "audio": None,
                    "avatar": info['avatar'],
                }
            yield json.dumps(output)

    # The host closes the discussion. If the client did not select the host,
    # load its profile separately instead of failing with KeyError.
    leader_info = team_members.get('主持人')
    if leader_info is None:
        leader_info = load_team_members(language, ['主持人'])['主持人']

    summary_prompts = {
        'zh': "作为主持人,请用中文总结团队的讨论并给出最终的结论或建议。",
        'en': "As the moderator, please summarize the team's discussion in English and provide final conclusions or recommendations.",
        'ko': "사회자로서 팀 토론을 한국어로 요약하고 최종 결론이나 권장 사항을 제시해 주세요.",
    }
    summary_prompt = (
        f"{summary_prompts.get(language, summary_prompts['zh'])}讨论内容如下:\n"
        f"{''.join(discussion)}\n"
        "你应该:\n"
        "保持客观性,不添加个人观点或偏见;\n"
        "使用简短明了语言,避免冗长或重复;\n"
        "保留每个角色关键信息主要论点;\n"
        "按照逻辑顺序组织信息,使总结易于理解;\n"
        "根据文本的长度和复杂程度,调整总结的详细程度;\n"
        "提供项目规划专家的流程图。\n"
    )
    try:
        summary = await get_ai_response(summary_prompt, language=language)
        print(f"生成的总结: {summary}")
        # Treat an empty or near-empty summary as a failure.
        if not summary or len(summary.strip()) < 10:
            raise ValueError("生成的总结过短或为空")
        discussion.append(f"\n总结:{summary}")
    except Exception as e:
        # Best effort: any failure is replaced with a canned apology.
        print(f"生成总结时发生错误: {str(e)}")
        summary = "很抱歉,生成总结时遇到了问题。请查看之前的讨论内容作为参考。"
        discussion.append(f"\n总结:{summary}")

    summary_tts_task_id = _submit_tts(summary, leader_info['voice'])
    summary_has_audio = await _await_tts_audio(summary_tts_task_id, tts_timeout)
    summary_json = {
        "post": leader_info['post'],
        "chunk": f"{leader_info['name']}:{summary}",
        "audio_task_id": summary_tts_task_id if summary_has_audio else None,
        "avatar": leader_info['avatar'],
    }
    if summary_has_audio:
        print(summary_json)
    yield json.dumps(summary_json)

    completion_json = {"chunk": "[DISCUSSION_COMPLETED]", "audio_task_id": summary_tts_task_id}
    print(f"发送讨论完成信号: {json.dumps(completion_json)}")
    yield json.dumps(completion_json)
|
||
|
||
async def discussion_stream(topic: str, max_rounds: int, session_id: str, selected_roles: List[str], language: str):
    """Adapt team_discussion_generator() into a Server-Sent-Events byte stream.

    Every message is forwarded as an SSE ``data:`` event, and the audio
    clip of each message is collected. When the discussion ends:

    * the clips are merged into ``<AUDIO_BASE_PATH>/<session_id>_combined.wav``;
    * the transcript is saved to the language's Redis session DB under
      *session_id*;
    * a final ``{"chunk": "[AUDIO_COMPLETED]"}`` event is sent.

    *language* must be a key of redis_session_clients; start_discussion
    normalizes it before calling this.
    """
    def _sse(payload):
        # One SSE event: "data: <payload>" terminated by a blank line.
        return f"data: {payload}\n\n".encode('utf-8')

    discussion_content = []
    audio_files = []
    async for chunk in team_discussion_generator(topic, max_rounds, selected_roles, language):
        chunk_data = json.loads(chunk)

        if "topic" in chunk_data:
            yield _sse(chunk)
            continue

        if "chunk" in chunk_data and chunk_data["chunk"] == "[DISCUSSION_COMPLETED]":
            # session_id comes from the client: strip path separators so it
            # cannot place the file outside AUDIO_BASE_PATH (e.g. "../x" or an
            # absolute path). Ordinary ids such as UUIDs are unchanged.
            safe_session_id = session_id.replace("/", "_").replace("\\", "_")
            output_path = os.path.join(AUDIO_BASE_PATH, f"{safe_session_id}_combined.wav")
            try:
                await concatenate_audio_files(audio_files, output_path)
            except Exception as exc:
                # Still persist the transcript; get_combined_audio() answers
                # 404 when combined_audio_path is empty.
                print(f"合并音频失败: {exc}")
                output_path = None

            # Save the transcript in the Redis DB for this language.
            discussion = {
                "topic": topic,
                "content": discussion_content,
                "timestamp": int(datetime.now(timezone.utc).timestamp()),  # seconds since epoch
                "combined_audio_path": output_path,
                "language": language,
            }
            redis_session_clients[language].set(session_id, json.dumps(discussion))
            print(f"讨论容已保存 {language} 数据库,合并音频路径: {output_path}")

            # Framed as a proper SSE event; it used to be yielded as a bare
            # JSON string that SSE clients never parsed as an event.
            yield _sse(json.dumps({"chunk": "[AUDIO_COMPLETED]"}))
            break

        discussion_content.append(chunk_data)
        if "chunk" in chunk_data and chunk_data.get("audio_task_id"):
            audio_path = await get_audio_path(chunk_data["audio_task_id"])
            if audio_path:
                audio_files.append(audio_path)
        yield _sse(chunk)
|
||
|
||
async def get_audio_path(task_id):
    """Return the on-disk audio path recorded for TTS task *task_id*, or None.

    The lookup chain is task_info:tts:<task_id> (for voice and text_hash),
    then tts:<text_hash> in that voice's Redis DB. A missing link, including
    a voice with no configured Redis DB, yields None instead of KeyError.
    """
    task_info = redis_task_client.get(f"task_info:tts:{task_id}")
    if not task_info:
        return None
    task_data = json.loads(task_info)
    redis_tts = voice_to_redis.get(task_data.get('voice'))
    if redis_tts is None:
        return None
    audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
    if not audio_info:
        return None
    return json.loads(audio_info)['path']
|
||
|
||
@user_app.get("/api/start-discussion")
@user_app.post("/api/start-discussion")
async def start_discussion(request: DiscussionRequest):
    """Validate the request and stream the discussion back as Server-Sent Events.

    An empty or unsupported language silently falls back to 'zh'.

    Raises:
        HTTPException(400): topic, session_id or selected_roles is empty.

    NOTE(review): the GET route also expects a JSON body, which many HTTP
    clients and proxies do not send with GET.
    """
    if not request.topic:
        raise HTTPException(status_code=400, detail="Topic is required")
    if not request.session_id:
        raise HTTPException(status_code=400, detail="Session ID is required")
    if not request.selected_roles:
        raise HTTPException(status_code=400, detail="At least one role must be selected")
    # Accept exactly the languages that have a session DB to save into
    # ('zh', 'en', 'ko'), the same check the read endpoints use.
    if request.language not in redis_session_clients:
        request.language = 'zh'

    # Team members are loaded inside the generator for this request only;
    # nothing is stored in module-level state.
    return StreamingResponse(
        discussion_stream(request.topic, request.max_rounds, request.session_id, request.selected_roles, request.language),
        media_type="text/event-stream",
    )
|
||
|
||
@user_app.get("/api/get-audio/{task_id}", response_model=List[str])
async def get_audio(task_id: str):
    """Return the public URL of a finished TTS task's audio file.

    Returns:
        A one-element list ``["/audio/<file name>"]``.
        NOTE(review): no /audio static mount is visible in this module;
        confirm it is served elsewhere.

    Raises:
        HTTPException(404): unknown task, or the task completed but its
            audio record or file is missing (previously mis-reported as
            "task does not exist").
        HTTPException(202): task still queued or processing.
        HTTPException(500): any other task status (e.g. failed).
    """
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not raw_status:
        raise HTTPException(status_code=404, detail="任务不存在")

    task_state = raw_status.decode('utf-8')
    if task_state in ("queued", "processing"):
        raise HTTPException(status_code=202, detail="音频文件正在生成中")
    if task_state != "completed":
        raise HTTPException(status_code=500, detail="任务处理失败")

    # Shared lookup; returns None (instead of raising KeyError) for an
    # unknown voice or a missing audio record.
    audio_path = await get_audio_path(task_id)
    if audio_path and os.path.exists(audio_path):
        return [f"/audio/{os.path.basename(audio_path)}"]
    raise HTTPException(status_code=404, detail="音频文件不存在")
|
||
|
||
@user_app.get("/api/get-combined-audio/{session_id}")
async def get_combined_audio(session_id: str, language: str = 'zh'):
    """Download the merged discussion WAV saved for *session_id* in *language*.

    Raises:
        HTTPException(400): *language* has no session store.
        HTTPException(404): no saved session, or its combined audio file is
            absent.
    """
    session_store = redis_session_clients.get(language)
    if session_store is None:
        raise HTTPException(status_code=400, detail="Unsupported language")

    raw_session = session_store.get(session_id)
    if not raw_session:
        raise HTTPException(status_code=404, detail="Combined audio not found")

    audio_path = json.loads(raw_session.decode('utf-8')).get("combined_audio_path")
    print(audio_path)
    if not (audio_path and os.path.exists(audio_path)):
        raise HTTPException(status_code=404, detail="Combined audio not found")
    return FileResponse(audio_path, media_type="audio/wav", filename=f"{session_id}_combined.wav")
|
||
|
||
@user_app.get("/getvoice")
async def get_available_voices():
    """Return the fixed list of voice names offered to clients.

    NOTE(review): "default" and "liuzhenyun" have no entry in voice_to_redis,
    so audio lookups for those voices cannot succeed in this module.
    """
    voices = (
        "default girl woman man leijun dufu hejiong mahuateng "
        "lidan yuhua liuzhenyun dabing luoxiang xuzhiyuan"
    ).split()
    return {"available_voices": voices}
|
||
|
||
@user_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str):
    """Report the raw status string stored for TTS task *task_id*.

    Raises:
        HTTPException(404): no status is recorded for the task.
    """
    raw_status = redis_task_client.get(f"task_status:tts:{task_id}")
    if not raw_status:
        raise HTTPException(status_code=404, detail="任务不存在")
    return {"status": raw_status.decode('utf-8')}
|
||
|
||
@user_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str):
    """Stream the WAV produced for TTS task *task_id*.

    The file's existence is checked before building the response. A stale
    Redis record previously made FileResponse fail while sending (HTTP 500),
    and an unknown voice raised KeyError; both now give a clean 404.

    Raises:
        HTTPException(404): no task info, no audio record, or the file is
            gone from disk.
    """
    audio_path = await get_audio_path(task_id)
    if audio_path and os.path.exists(audio_path):
        return FileResponse(audio_path, media_type="audio/wav")
    raise HTTPException(status_code=404, detail="音频文件不存在")
|
||
|
||
@user_app.get("/api/get-discussion/{session_id}")
async def get_discussion(session_id: str, language: str = 'zh'):
    """Return the saved discussion document for *session_id* in *language*.

    Raises:
        HTTPException(400): *language* has no session store.
        HTTPException(404): nothing stored under *session_id*.
        HTTPException(500): the stored value is not valid JSON.
    """
    session_store = redis_session_clients.get(language)
    if session_store is None:
        raise HTTPException(status_code=400, detail="Unsupported language")

    stored = session_store.get(session_id)
    if not stored:
        raise HTTPException(status_code=404, detail="Discussion not found")

    try:
        payload = json.loads(stored)
    except json.JSONDecodeError:
        raise HTTPException(status_code=500, detail="Invalid JSON data in Redis")
    return JSONResponse(content=payload)
|
||
|
||
@user_app.get("/api/get-session-id")
async def get_session_id(language: str = 'zh'):
    """List saved discussions (session id, topic, UTC time) for *language*.

    Keys are enumerated incrementally with SCAN. KEYS walks the whole keyspace
    in a single blocking server call. A malformed entry is logged and skipped
    instead of turning the whole listing into a 500.

    Raises:
        HTTPException(400): *language* has no session store.
        HTTPException(500): Redis access or another unexpected error failed.
        HTTPException(404): no well-formed sessions were found.
    """
    if language not in redis_session_clients:
        raise HTTPException(status_code=400, detail="Unsupported language")

    session_store = redis_session_clients[language]
    result = []
    try:
        for raw_key in session_store.scan_iter():
            session_id = raw_key.decode('utf-8')
            value = session_store.get(raw_key)
            if not value:
                # Key expired or was deleted between SCAN and GET.
                continue
            try:
                value_json = json.loads(value.decode('utf-8'))
                if not (isinstance(value_json, dict) and 'topic' in value_json and 'timestamp' in value_json):
                    continue
                utc_time = datetime.fromtimestamp(value_json['timestamp'], tz=timezone.utc)
            except (ValueError, TypeError, OverflowError, OSError) as exc:
                print(f"Skipping malformed {language} session {session_id}: {exc}")
                continue
            result.append({
                "session_id": session_id,
                "topic": value_json['topic'],
                "timestamp": utc_time.strftime("%Y-%m-%d %H:%M:%S UTC"),
                "language": language,
            })
    except Exception as e:
        print(f"Error fetching {language} session data: {e}")
        raise HTTPException(status_code=500, detail=f"Error fetching session data for {language}")

    if not result:
        raise HTTPException(status_code=404, detail=f"No session data found for {language}")
    return JSONResponse(content={"sessions": result})
|
||
|
||
# Serve avatar images under /user/avatar. AVATAR_BASE_PATH is the module
# constant defined alongside AUDIO_BASE_PATH near the top of the file.
# StaticFiles checks at construction time that the directory exists, so the
# diagnostics below are printed first.
print(f"Avatar base path: {AVATAR_BASE_PATH}")
print(f"Avatar directory exists: {os.path.exists(AVATAR_BASE_PATH)}")
if os.path.isdir(AVATAR_BASE_PATH):
    print(f"Avatar directory contents: {os.listdir(AVATAR_BASE_PATH)}")
else:
    # os.listdir() on a missing path raised FileNotFoundError here, hiding
    # the clearer error StaticFiles reports below.
    print("Avatar directory contents: <missing>")

user_app.mount("/avatar", StaticFiles(directory=AVATAR_BASE_PATH), name="avatar")
|
||
|
||
# Database settings (MySQL via PyMySQL), read from the environment.
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")
DB_PORT = os.getenv("DB_PORT")

# Credentials are URL-escaped. A password containing '@', ':', '/' or '%'
# would otherwise corrupt the URL that SQLAlchemy parses.
SQLALCHEMY_DATABASE_URL = (
    f"mysql+pymysql://{quote_plus(DB_USER or '')}:{quote_plus(DB_PASSWORD or '')}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
# pool_pre_ping tests a pooled connection before handing it out. Connections
# closed server-side (e.g. after MySQL's idle timeout) are then replaced
# instead of failing the next request.
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()
|
||
|
||
# Password hashing with bcrypt; "auto" marks every non-default scheme deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Extracts the bearer token from the Authorization header for protected
# routes. The matching login route is POST /token on user_app.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||
|
||
class User(Base):
    """ORM mapping for rows of the ``user`` table."""

    __tablename__ = "user"

    # Primary key: a UUID4 string generated by register().
    uuid = Column(String(length=36), primary_key=True, index=True)
    username = Column(String(length=50), unique=True, index=True, nullable=False)
    email = Column(String(length=100), unique=True, index=True)
    # Hash produced by get_password_hash(); never the plain password.
    hashed_password = Column(String(length=255))
    # Filled in by the database server on INSERT.
    created_at = Column(DateTime, server_default=func.now())
    # Updated by the /token login route.
    last_login = Column(DateTime)
    # NOTE(review): defaults to False and nothing in this module sets it True.
    is_active = Column(Boolean, default=False)
|
||
|
||
# Create any missing tables at import time; existing tables are left untouched.
Base.metadata.create_all(bind=engine, checkfirst=True)
|
||
|
||
class UserCreate(BaseModel):
    """Registration payload accepted by POST /register."""

    username: str
    email: EmailStr  # must be a syntactically valid e-mail address
    password: str  # plain text; hashed before it is stored
|
||
|
||
def get_db():
    """FastAPI dependency: yield one SQLAlchemy session per request.

    The session is closed when the request finishes, even if it raised.
    """
    with SessionLocal() as session:
        yield session
|
||
|
||
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches the stored *hashed_password*."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||
|
||
def get_password_hash(password):
    """Hash *password* with the module's CryptContext (bcrypt) for storage."""
    hashed = pwd_context.hash(password)
    return hashed
|
||
|
||
def get_user(db: Session, username: str):
    """Return the User whose username equals *username*, or None."""
    return db.query(User).filter_by(username=username).first()
|
||
|
||
def authenticate_user(db: Session, username: str, password: str):
    """Check a username/password pair.

    Returns:
        The matching User, or False when the user is unknown or the password
        does not match.
    """
    candidate = get_user(db, username)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return False
|
||
|
||
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """FastAPI dependency: resolve the bearer token to its User row.

    The token is a key "session:<token>" in redis_client, written at login.

    Raises:
        HTTPException(401): unknown/expired token, or the user no longer exists.
    """
    cached = redis_client.get(f"session:{token}")
    if not cached:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的token")

    username = json.loads(cached)['username']
    current = get_user(db, username)
    if current is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
    return current
|
||
|
||
@user_app.post("/register")
def register(user: UserCreate, db: Session = Depends(get_db)):
    """Create a new user with a bcrypt-hashed password.

    The e-mail column is unique as well, so a duplicate e-mail used to crash
    the commit with IntegrityError (HTTP 500). It is now checked up front,
    and a concurrent registration that wins the race is mapped to a 400.

    Raises:
        HTTPException(400): username or e-mail already registered.
    """
    from sqlalchemy.exc import IntegrityError

    if get_user(db, username=user.username):
        raise HTTPException(status_code=400, detail="用户名已被注册")
    if db.query(User).filter(User.email == user.email).first():
        raise HTTPException(status_code=400, detail="邮箱已被注册")

    new_user = User(
        uuid=str(uuid.uuid4()),
        username=user.username,
        email=user.email,
        hashed_password=get_password_hash(user.password),
    )
    db.add(new_user)
    try:
        db.commit()
    except IntegrityError:
        # Another request inserted the same username/e-mail after our checks.
        db.rollback()
        raise HTTPException(status_code=400, detail="用户名或邮箱已被注册")

    return {"message": "用户创建成功"}
|
||
|
||
@user_app.post("/token")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Exchange username/password for an opaque bearer token.

    The token is a random UUID stored in redis_client as "session:<token>"
    with a 7-day TTL. The user's last_login is updated as a side effect.

    Raises:
        HTTPException(401): wrong username or password.
    """
    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码不正确",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = str(uuid.uuid4())
    session_ttl_seconds = 7 * 24 * 60 * 60  # 7 days
    session_payload = json.dumps({
        "id": user.uuid,
        "username": user.username,
        "is_active": user.is_active,
    })
    redis_client.setex(f"session:{token}", session_ttl_seconds, session_payload)

    user.last_login = datetime.now(timezone.utc)
    db.commit()

    return {"access_token": token, "token_type": "bearer"}
|
||
|
||
@user_app.get("/me")
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the authenticated user's profile without secret fields.

    Returning the ORM object directly let FastAPI serialize every loaded
    column, including ``hashed_password``. Only non-secret columns are
    exposed now.
    """
    return {
        "uuid": current_user.uuid,
        "username": current_user.username,
        "email": current_user.email,
        "created_at": current_user.created_at,
        "last_login": current_user.last_login,
        "is_active": current_user.is_active,
    }
|
||
|
||
@user_app.get("/login")
async def login_page():
    """Serve the static login page.

    NOTE(review): "login.html" is resolved against the process's working
    directory, not this file's location.
    """
    page = "login.html"
    return FileResponse(page)
|
||
|
||
@user_app.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
    """Delete the Redis session for *token*.

    Always reports success, even when the session had already expired.
    """
    session_key = f"session:{token}"
    redis_client.delete(session_key)
    return {"message": "Successfully logged out"}
|
||
|
||
@user_app.get("/user-info")
async def get_user_info(current_user: User = Depends(get_current_user)):
    """Return display info for the logged-in user.

    "avatar" is the upper-cased first character of the username. Slicing
    (``[:1]``) replaces indexing (``[0]``), which raised IndexError (HTTP 500)
    for an empty username; UserCreate does not forbid empty usernames.
    """
    return {
        "username": current_user.username,
        "email": current_user.email,
        "avatar": current_user.username[:1].upper(),
    }
|
||
|
||
if __name__ == "__main__":
    # Development entry point: serve the root app (user_app is mounted at
    # /user) on all interfaces.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")