"""Team-discussion service: LLM-driven multi-role brainstorming with TTS audio.

Exposes a FastAPI sub-application (mounted at ``/user``) that streams a
multi-round discussion as server-sent events, dispatches text-to-speech jobs
over Kafka, tracks TTS results in Redis, and provides simple user
registration / token login backed by MySQL + Redis sessions.
"""
import asyncio
import json
import os
import time
import uuid
from datetime import datetime, timezone
from typing import List

import httpx
import redis
import requests
from dotenv import load_dotenv
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from kafka import KafkaProducer
from passlib.context import CryptContext
from pydantic import BaseModel, EmailStr
from pydub import AudioSegment
from sqlalchemy import Boolean, Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from sqlalchemy.sql import func

# Load environment variables from the .env file.
load_dotenv()

app = FastAPI()
user_app = FastAPI()
app.mount("/user", user_app)

# Allow cross-origin requests.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the real front-end origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TTS_TOPIC = os.getenv('KAFKA_TTS_TOPIC')

# Redis settings (every DB index is required; int(None) fails fast at import).
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = int(os.getenv('REDIS_TASK_DB'))
REDIS_SESSION_DB_ZH = int(os.getenv('REDIS_SESSION_DB_ZH'))
REDIS_SESSION_DB_EN = int(os.getenv('REDIS_SESSION_DB_EN'))
REDIS_SESSION_DB_KO = int(os.getenv('REDIS_SESSION_DB_KO'))
REDIS_REGISTER_DB = int(os.getenv('REDIS_REGISTER_DB'))
REDIS_DB = int(os.getenv('REDIS_DB'))


def _make_redis(db):
    """Create a client for logical database *db* on the shared Redis server."""
    return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=db, password=REDIS_PASSWORD)


# TTS task status/info, login sessions, and registration data.
redis_task_client = _make_redis(REDIS_TASK_DB)
redis_client = _make_redis(REDIS_DB)
redis_register_client = _make_redis(REDIS_REGISTER_DB)

# One session database per supported discussion language.
redis_session_clients = {
    'zh': _make_redis(REDIS_SESSION_DB_ZH),
    'en': _make_redis(REDIS_SESSION_DB_EN),
    'ko': _make_redis(REDIS_SESSION_DB_KO),
}

# Environment variable naming the Redis DB that indexes audio for each voice.
_VOICE_DB_ENV = {
    'girl': 'REDIS_GIRL_DB',
    'woman': 'REDIS_WOMAN_DB',
    'man': 'REDIS_MAN_DB',
    'leijun': 'REDIS_LEIJUN_DB',
    'dufu': 'REDIS_DUFU_DB',
    'hejiong': 'REDIS_HEJIONG_DB',
    'mahuateng': 'REDIS_MAHUATENG_DB',
    'lidan': 'REDIS_LIDAN_DB',
    'luoxiang': 'REDIS_LUOXIANG_DB',
    'xuzhiyuan': 'REDIS_XUZHIYUAN_DB',
    'dabing': 'REDIS_DABING_DB',
    'yuhua': 'REDIS_YUHUA_DB',
}
voice_to_redis = {
    voice: _make_redis(int(os.getenv(env_name)))
    for voice, env_name in _VOICE_DB_ENV.items()
}

# TTS polling: interval between Redis checks and the maximum total wait, so a
# stalled TTS worker cannot hang a discussion stream forever.
TTS_POLL_INTERVAL = 0.5
TTS_WAIT_TIMEOUT = float(os.getenv('TTS_WAIT_TIMEOUT', '300'))

# Audio and avatar file locations.
AUDIO_BASE_PATH = "/obscura/task/audio_files"
AVATAR_BASE_PATH = "/obscura/task/avatar"

# Directory holding per-language role definition JSON files.
TEAM_MEMBERS_PATH = "team_members"

# Role display name -> JSON file stem under TEAM_MEMBERS_PATH/<language>/.
ROLE_TO_FILENAME = {
    "技术专家": "tech_expert",
    "创意专家": "creative",
    "数据分析专家": "dataanalyst",
    "项目规划专家": "pragmatist",
    "市场营销专家": "marketing_expert",
    "财务专家": "financial_expert",
    "马化腾": "mahuateng",
    "李诞": "lidan",
    "罗翔": "luoxiang",
    "许知远": "xuzhiyuan",
    "大冰": "dabing",
    "余华": "yuhua",
    "雷军": "leijun",
}


def load_team_members(language='zh', selected_roles=None):
    """Build the role -> member-info mapping for a discussion.

    The host ('主持人') is built in-code and is included when *selected_roles*
    is None or contains it. Other roles are read from
    ``TEAM_MEMBERS_PATH/<language>/<file>.json``; unknown roles and missing
    files are skipped silently. When *selected_roles* is falsy, every role in
    ROLE_TO_FILENAME is loaded.

    Each member dict is used with the keys 'name', 'post', 'voice', 'avatar'
    and 'personality' elsewhere — presumably the JSON files provide all of
    them; TODO confirm.
    """
    team_members = {}

    # Host definition, localized per language.
    leader_info = {
        'name': {'zh': '何主持', 'en': 'Host He', 'ko': '사회자 허'},
        'post': {'zh': '主持人', 'en': 'Host', 'ko': '사회자'},
        'voice': 'hejiong',
        'avatar': 'cn__00138_.png',
        'personality': {
            'zh': """经验丰富,专门负责引导六人跨部门小组的头脑风暴会议。你的任务是有效地主持讨论,基于我提出的具体问题展开对话。请记住,你的回应和引导必须简洁明了。
作为主持人,你应该:
用简短的话语开场,迅速将注意力集中到我提出的问题上,
你的目标是通过简洁有力的引导,创造一个富有成效的讨论环境,让每个团队成员都能围绕我提出的问题贡献有价值的见解。""",
            'en': """Experienced in guiding brainstorming sessions for a six-person cross-departmental team. Your task is to effectively moderate the discussion based on the specific questions I raise. Remember, your responses and guidance must be concise and clear.
As the host, you should:
Open with brief remarks, quickly focusing attention on the question I've posed,
Your goal is to create a productive discussion environment through concise and powerful guidance, allowing each team member to contribute valuable insights around the question I've raised.""",
            'ko': """6인 부서 간 팀의 브레인스토밍 회의를 이끄는 데 경험이 풍부합니다. 당신의 임무는 내가 제기한 구체적인 질문 바탕으로 토론을 효과적으로 진행하는 것입니다. 당신의 응답과 안내는 간결하고 명확해야 함을 기억하세요.
사회자로서 당신은:
간단한 말로 시작하여 내가 제기한 질문에 빠르게 주의를 집중시켜야 합니다,
당신의 목표는 간결하고 강력한 안내를 통해 생산적인 토론 환경을 만들어, 각 팀원이 내가 제기한 질문에 대해 가치 있는 통찰력을 기여할 수 있도록 하는 것입니다.""",
        },
    }

    if selected_roles is None or '主持人' in selected_roles:
        team_members['主持人'] = {
            'name': leader_info['name'][language],
            'post': leader_info['post'][language],
            'voice': leader_info['voice'],
            'avatar': leader_info['avatar'],
            'personality': leader_info['personality'][language],
        }

    lang_path = os.path.join(TEAM_MEMBERS_PATH, language)
    # With no explicit selection, load every file-backed role (the host is not file-backed).
    roles_to_load = selected_roles if selected_roles else list(ROLE_TO_FILENAME.keys())
    for role in roles_to_load:
        if role == '主持人' or role not in ROLE_TO_FILENAME:
            continue
        file_path = os.path.join(lang_path, f"{ROLE_TO_FILENAME[role]}.json")
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                team_members[role] = json.load(file)

    return team_members


# Kafka producer used to dispatch TTS jobs (JSON-encoded values).
producer = KafkaProducer(
    bootstrap_servers=KAFKA_BROKER,
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
)


class DiscussionRequest(BaseModel):
    topic: str
    max_rounds: int
    session_id: str
    selected_roles: List[str]
    language: str  # discussion language: 'zh', 'en' or 'ko'


# LLM API settings. The key must come from the environment; secrets must not
# be committed to source (rotate any key that previously was).
SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_API_KEY')
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/chat/completions"
SILICONFLOW_TIMEOUT = float(os.getenv('SILICONFLOW_TIMEOUT', '60'))


async def get_ai_response(prompt, model='deepseek-ai/DeepSeek-V2.5', language='zh'):
    """Ask the chat-completions API to answer *prompt* in *language*.

    Uses an async HTTP client so the event loop (and every other streaming
    client) is not blocked while the model generates.

    Returns the text reply. When the reply parses as a JSON object, its
    'response' field is returned instead (falling back to the raw text).
    Returns a fixed Chinese fallback message when the API returns no choice.

    Raises:
        httpx.HTTPError: on transport failure, timeout or a non-2xx status.
    """
    language_prompts = {
        'zh': "请用中文回答以下问题:",
        'en': "Please answer the following question in English:",
        'ko': "다음 질문에 한국어로 답해주세요:",
    }
    lang_prompt = language_prompts.get(language, language_prompts['zh'])

    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": lang_prompt},
            {"role": "user", "content": prompt},
        ],
        "stream": False,
        "max_tokens": 512,
        # NOTE(review): an empty stop sequence looks accidental; confirm intent.
        "stop": [""],
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "frequency_penalty": 0.5,
        "n": 1,
        "response_format": {"type": "text"},  # plain text rather than "json_object"
    }
    headers = {
        "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=SILICONFLOW_TIMEOUT) as client:
        response = await client.post(SILICONFLOW_API_URL, json=payload, headers=headers)
    response.raise_for_status()
    response_data = response.json()

    choices = response_data.get('choices') or []
    if not choices:
        return "无法生成回应。请重试。"
    content = (choices[0].get('message') or {}).get('content')
    if content is None:
        return "无法生成回应。请重试。"

    # The model may still wrap its answer in JSON; unwrap only JSON objects
    # (a bare JSON number/list/string has no .get and is returned verbatim).
    try:
        json_content = json.loads(content)
    except json.JSONDecodeError:
        return content
    if isinstance(json_content, dict):
        return json_content.get('response', content)
    return content


async def wait_for_audio_file(task_id, timeout=30):
    """Poll AUDIO_BASE_PATH for a ``<task_id>*.wav`` file.

    Returns the first matching file name, or None after *timeout* seconds.
    """
    start_time = time.time()
    while time.time() - start_time < timeout:
        audio_files = [
            f for f in os.listdir(AUDIO_BASE_PATH)
            if f.startswith(task_id) and f.endswith('.wav')
        ]
        if audio_files:
            return audio_files[0]
        await asyncio.sleep(0.5)
    return None


async def concatenate_audio_files(audio_files, output_path):
    """Concatenate the WAV files in *audio_files* (in order) into *output_path*.

    Missing input files are skipped. The pydub work is blocking, so it runs in
    the default executor instead of on the event loop.
    """
    def _combine():
        combined = AudioSegment.empty()
        for audio_file in audio_files:
            if not os.path.exists(audio_file):
                print(f"Skipping missing audio file: {audio_file}")
                continue
            combined += AudioSegment.from_wav(audio_file)
        # export() returns the opened output file object; close it explicitly.
        combined.export(output_path, format="wav").close()

    await asyncio.get_running_loop().run_in_executor(None, _combine)


def _get_tts_status(task_id):
    """Return the decoded TTS task status string, or None if unknown."""
    raw = redis_task_client.get(f"task_status:tts:{task_id}")
    return raw.decode('utf-8') if raw else None


def _resolve_tts_audio_path(task_id):
    """Return the audio file path recorded for TTS task *task_id*, or None.

    Follows task_info:tts:<id> -> the voice's Redis DB -> tts:<text_hash>.
    Unknown voices yield None rather than a KeyError.
    """
    task_info = redis_task_client.get(f"task_info:tts:{task_id}")
    if not task_info:
        return None
    task_data = json.loads(task_info)
    redis_tts = voice_to_redis.get(task_data.get('voice'))
    if redis_tts is None:
        return None
    audio_info = redis_tts.get(f"tts:{task_data['text_hash']}")
    if not audio_info:
        return None
    return json.loads(audio_info)['path']


def _publish_tts(text, voice):
    """Send a TTS job for *text* in *voice* to Kafka; return its new task id."""
    task_id = str(uuid.uuid4())
    producer.send(KAFKA_TTS_TOPIC, {
        'task_id': task_id,
        'text': text,
        'text_hash': task_id,
        'voice': voice,
    })
    return task_id


async def _wait_for_tts(task_id, timeout=None):
    """Poll Redis until TTS task *task_id* has resolvable audio.

    Returns True once the status is "completed" and the audio path is
    recorded; False when the status is "failed" or *timeout* seconds
    (default TTS_WAIT_TIMEOUT) elapse. A "completed" status whose audio
    record is not visible yet keeps polling until the deadline.
    """
    limit = TTS_WAIT_TIMEOUT if timeout is None else timeout
    deadline = time.monotonic() + limit
    while time.monotonic() < deadline:
        state = _get_tts_status(task_id)
        if state == "completed" and _resolve_tts_audio_path(task_id):
            return True
        if state == "failed":
            return False
        await asyncio.sleep(TTS_POLL_INTERVAL)
    return False


def _speech_event(member, text, audio_task_id):
    """Build the event payload for one spoken contribution by *member*.

    *audio_task_id* is None when no audio is available; in that case the
    payload also carries "audio": None, the key the member-failure payload
    has always used.
    """
    event = {
        "post": f"{member['post']}",
        "chunk": f"{member['name']}:{text}",
        "audio_task_id": audio_task_id,
        "avatar": member['avatar'],
    }
    if audio_task_id is None:
        event["audio"] = None
    return event


async def team_discussion_generator(topic, max_rounds, selected_roles, language):
    """Run the discussion and yield JSON-encoded events.

    Yields, in order: a {"topic": ...} header; one speech event per member per
    round; the host's summary event; and finally
    {"chunk": "[DISCUSSION_COMPLETED]", "audio_task_id": <summary task id>}.
    Each speech is sent to TTS and waited on (bounded) before it is yielded.
    """
    team_members = load_team_members(language, selected_roles)

    discussion_topics = {
        'zh': "讨论主题",
        'en': "Discussion Topic",
        'ko': "토론 주제",
    }
    topic_header = discussion_topics.get(language, discussion_topics['zh'])
    discussion = [f"{topic_header}: {topic}"]
    yield json.dumps({"topic": f"{topic_header}: {topic}\n", "audio": None})

    for _round_index in range(max_rounds):
        for info in team_members.values():
            prompt = f"""你是一个团队中的{info['post']},{info['personality']}。
团队正在讨论以下问题:"{topic}",请用{language}语言回答
当前讨论进展:
{''.join(discussion)}
请根据你的角色和特点,对这个问题发表你的看法或对其他成员的观点进行回应,在多轮讨论中,应该根据讨论历史不断优化回答。
在回答时打招呼内容不要带有自己的角色和别人的角色,语气应避免单调和机械,尽可能口语化,回答内容保持简洁,控制在150字以内
"""
            response = await get_ai_response(prompt, language=language)
            discussion.append(f"\n{info['name']}({info['post']}):{response}")

            # Generate speech and wait (bounded) for the audio.
            tts_task_id = _publish_tts(response, info['voice'])
            has_audio = await _wait_for_tts(tts_task_id)
            output = _speech_event(info, response, tts_task_id if has_audio else None)
            print(json.dumps({"type": "output", "content": output}, ensure_ascii=False, indent=2))
            yield json.dumps(output)

    # The host closes the discussion even when it was not among the selected
    # roles (load_team_members only includes it when selected).
    leader_info = team_members.get('主持人') or load_team_members(language, ['主持人'])['主持人']
    summary_prompts = {
        'zh': "作为主持人,请用中文总结团队的讨论并给出最终的结论或建议。",
        'en': "As the moderator, please summarize the team's discussion in English and provide final conclusions or recommendations.",
        'ko': "사회자로서 팀 토론을 한국어로 요약하고 최종 결론이나 권장 사항을 제시해 주세요.",
    }
    summary_prompt = f"""{summary_prompts.get(language, summary_prompts['zh'])}讨论内容如下:
{''.join(discussion)}
你应该:
保持客观性,不添加个人观点或偏见;
使用简短明了语言,避免冗长或重复;
保留每个角色关键信息主要论点;
按照逻辑顺序组织信息,使总结易于理解;
根据文本的长度和复杂程度,调整总结的详细程度;
提供项目规划专家的流程图。
"""

    try:
        summary = await get_ai_response(summary_prompt, language=language)
        print(f"生成的总结: {summary}")
        # Treat an empty or very short summary as a failure.
        if not summary or len(summary.strip()) < 10:
            raise ValueError("生成的总结过短或为空")
        discussion.append(f"\n总结:{summary}")
    except Exception as e:
        # Best effort: fall back to a canned summary rather than ending the stream.
        print(f"生成总结时发生错误: {str(e)}")
        summary = "很抱歉,生成总结时遇到了问题。请查看之前的讨论内容作为参考。"
        discussion.append(f"\n总结:{summary}")

    # Speak the summary with the host's voice.
    summary_tts_task_id = _publish_tts(summary, leader_info['voice'])
    has_summary_audio = await _wait_for_tts(summary_tts_task_id)
    summary_json = _speech_event(
        leader_info, summary, summary_tts_task_id if has_summary_audio else None
    )
    print(summary_json)
    yield json.dumps(summary_json)

    completion_json = {"chunk": "[DISCUSSION_COMPLETED]", "audio_task_id": summary_tts_task_id}
    print(f"发送讨论完成信号: {json.dumps(completion_json)}")
    yield json.dumps(completion_json)


def _sse(payload):
    """Frame a JSON string as one server-sent event, UTF-8 encoded."""
    return f"data: {payload}\n\n".encode('utf-8')


async def discussion_stream(topic: str, max_rounds: int, session_id: str, selected_roles: List[str], language: str):
    """Relay generator events as SSE and persist the finished discussion.

    On "[DISCUSSION_COMPLETED]" the collected speech audio is concatenated into
    ``<AUDIO_BASE_PATH>/<session_id>_combined.wav``, the transcript is stored in
    the language's session Redis DB under *session_id*, and a final
    {"chunk": "[AUDIO_COMPLETED]"} event is sent.
    """
    discussion_content = []
    audio_files = []
    async for chunk in team_discussion_generator(topic, max_rounds, selected_roles, language):
        chunk_data = json.loads(chunk)
        if "topic" in chunk_data:
            yield _sse(chunk)
        elif "chunk" in chunk_data:
            if chunk_data["chunk"] == "[DISCUSSION_COMPLETED]":
                output_path = os.path.join(AUDIO_BASE_PATH, f"{session_id}_combined.wav")
                # An audio failure must not lose the transcript; the download
                # endpoint checks that the combined file exists before serving it.
                try:
                    await concatenate_audio_files(audio_files, output_path)
                except Exception as e:
                    print(f"Failed to combine audio for session {session_id}: {e}")

                discussion = {
                    "topic": topic,
                    "content": discussion_content,
                    "timestamp": int(datetime.now(timezone.utc).timestamp()),  # seconds since epoch
                    "combined_audio_path": output_path,
                    "language": language,
                }
                redis_session_clients[language].set(session_id, json.dumps(discussion))
                print(f"讨论容已保存 {language} 数据库,合并音频路径: {output_path}")
                yield _sse(json.dumps({"chunk": "[AUDIO_COMPLETED]"}))
                break
            discussion_content.append(chunk_data)
            if chunk_data.get("audio_task_id"):
                audio_path = await get_audio_path(chunk_data["audio_task_id"])
                if audio_path:
                    audio_files.append(audio_path)
            yield _sse(chunk)
        else:
            # Any other event shape is recorded and forwarded unchanged.
            discussion_content.append(chunk_data)
            yield _sse(chunk)


async def get_audio_path(task_id):
    """Return the audio file path recorded for TTS task *task_id*, or None."""
    return _resolve_tts_audio_path(task_id)


@user_app.get("/api/start-discussion")
@user_app.post("/api/start-discussion")
async def start_discussion(request: DiscussionRequest):
    """Validate the request and stream the discussion as text/event-stream.

    Unsupported or empty languages fall back to 'zh'.
    """
    if not request.topic:
        raise HTTPException(status_code=400, detail="Topic is required")
    if not request.session_id:
        raise HTTPException(status_code=400, detail="Session ID is required")
    if not request.selected_roles:
        raise HTTPException(status_code=400, detail="At least one role must be selected")
    if not request.language or request.language not in ['zh', 'en', 'ko']:
        request.language = 'zh'  # default to Chinese

    return StreamingResponse(
        discussion_stream(request.topic, request.max_rounds, request.session_id,
                          request.selected_roles, request.language),
        media_type="text/event-stream",
    )


@user_app.get("/api/get-audio/{task_id}", response_model=List[str])
async def get_audio(task_id: str):
    """Return ``["/audio/<file>"]`` for a finished TTS task.

    Raises 202 while queued/processing, 500 for any other non-completed
    status, and 404 when the task or its audio file cannot be found.
    """
    task_state = _get_tts_status(task_id)
    if task_state == "completed":
        audio_path = _resolve_tts_audio_path(task_id)
        if audio_path and os.path.exists(audio_path):
            return [f"/audio/{os.path.basename(audio_path)}"]
    elif task_state in ("queued", "processing"):
        raise HTTPException(status_code=202, detail="音频文件正在生成中")
    elif task_state is not None:
        raise HTTPException(status_code=500, detail="任务处理失败")
    raise HTTPException(status_code=404, detail="任务不存在")


@user_app.get("/api/get-combined-audio/{session_id}")
async def get_combined_audio(session_id: str, language: str = 'zh'):
    """Download the combined WAV recorded for a finished discussion session."""
    if language not in redis_session_clients:
        raise HTTPException(status_code=400, detail="Unsupported language")

    session_data = redis_session_clients[language].get(session_id)
    if session_data:
        session_info = json.loads(session_data.decode('utf-8'))
        audio_path = session_info.get("combined_audio_path")
        print(audio_path)
        if audio_path and os.path.exists(audio_path):
            return FileResponse(audio_path, media_type="audio/wav", filename=f"{session_id}_combined.wav")

    raise HTTPException(status_code=404, detail="Combined audio not found")


@user_app.get("/getvoice")
async def get_available_voices():
    """Return the fixed list of advertised voice names."""
    # NOTE(review): 'default' and 'liuzhenyun' have no entry in voice_to_redis.
    valid_voices = ["default", "girl", "woman", "man", "leijun", "dufu", "hejiong",
                    "mahuateng", "lidan", "yuhua", "liuzhenyun", "dabing", "luoxiang", "xuzhiyuan"]
    return {"available_voices": valid_voices}


@user_app.get("/tts_result/{task_id}")
async def get_tts_result(task_id: str):
    """Return the raw status string of a TTS task, or 404 if unknown."""
    task_state = _get_tts_status(task_id)
    if task_state is None:
        raise HTTPException(status_code=404, detail="任务不存在")
    return {"status": task_state}


@user_app.get("/tts_audio/{task_id}")
async def get_tts_audio(task_id: str):
    """Serve the WAV produced by a TTS task, or 404 if it is not available."""
    audio_path = _resolve_tts_audio_path(task_id)
    if audio_path and os.path.exists(audio_path):
        return FileResponse(audio_path, media_type="audio/wav")
    raise HTTPException(status_code=404, detail="音频文件不存在")


@user_app.get("/api/get-discussion/{session_id}")
async def get_discussion(session_id: str, language: str = 'zh'):
    """Return the stored discussion record for *session_id*."""
    if language not in redis_session_clients:
        raise HTTPException(status_code=400, detail="Unsupported language")

    discussion_content = redis_session_clients[language].get(session_id)
    if discussion_content:
        try:
            discussion_data = json.loads(discussion_content)
            return JSONResponse(content=discussion_data)
        except json.JSONDecodeError:
            raise HTTPException(status_code=500, detail="Invalid JSON data in Redis")

    raise HTTPException(status_code=404, detail="Discussion not found")


@user_app.get("/api/get-session-id")
async def get_session_id(language: str = 'zh'):
    """List stored sessions (id, topic, UTC timestamp) for *language*.

    Uses SCAN instead of KEYS so large databases do not block Redis, and
    skips individual records that are not valid discussion JSON instead of
    failing the whole listing.
    """
    if language not in redis_session_clients:
        raise HTTPException(status_code=400, detail="Unsupported language")

    session_client = redis_session_clients[language]
    result = []
    try:
        for raw_key in session_client.scan_iter('*'):
            session_id = raw_key.decode('utf-8')
            value = session_client.get(raw_key)
            if not value:
                continue
            try:
                value_json = json.loads(value.decode('utf-8'))
                if not (isinstance(value_json, dict) and 'topic' in value_json and 'timestamp' in value_json):
                    continue
                utc_time = datetime.fromtimestamp(value_json['timestamp'], tz=timezone.utc)
            except (ValueError, TypeError, OverflowError, OSError):
                continue  # not a discussion record
            result.append({
                "session_id": session_id,
                "topic": value_json['topic'],
                "timestamp": utc_time.strftime("%Y-%m-%d %H:%M:%S UTC"),
                "language": language,
            })
    except Exception as e:
        print(f"Error fetching {language} session data: {e}")
        raise HTTPException(status_code=500, detail=f"Error fetching session data for {language}")

    if not result:
        raise HTTPException(status_code=404, detail=f"No session data found for {language}")

    return JSONResponse(content={"sessions": result})


# Startup diagnostics for the avatar directory (must exist for StaticFiles).
print(f"Avatar base path: {AVATAR_BASE_PATH}")
print(f"Avatar directory exists: {os.path.exists(AVATAR_BASE_PATH)}")
if os.path.isdir(AVATAR_BASE_PATH):
    print(f"Avatar directory contents: {os.listdir(AVATAR_BASE_PATH)}")

# Serve avatar images under /avatar.
user_app.mount("/avatar", StaticFiles(directory=AVATAR_BASE_PATH), name="avatar")

# Database settings
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")
DB_PORT = os.getenv("DB_PORT")

SQLALCHEMY_DATABASE_URL = f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Password hashing configuration.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 password-flow bearer token extraction.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


class User(Base):
    """Registered user account."""
    __tablename__ = "user"

    uuid = Column(String(36), primary_key=True, index=True)
    username = Column(String(50), unique=True, index=True, nullable=False)
    email = Column(String(100), unique=True, index=True)
    hashed_password = Column(String(255))
    created_at = Column(DateTime, server_default=func.now())
    last_login = Column(DateTime)
    is_active = Column(Boolean, default=False)


# Create database tables if they do not exist.
Base.metadata.create_all(bind=engine)


class UserCreate(BaseModel):
    """Registration payload."""
    username: str
    email: EmailStr
    password: str


def get_db():
    """Yield a database session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def verify_password(plain_password, hashed_password):
    """Check *plain_password* against a stored hash."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    """Hash *password* for storage."""
    return pwd_context.hash(password)


def get_user(db: Session, username: str):
    """Return the user with *username*, or None."""
    return db.query(User).filter(User.username == username).first()


def authenticate_user(db: Session, username: str, password: str):
    """Return the user when the credentials match, otherwise False."""
    user = get_user(db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user


async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    """Resolve the bearer token (a Redis session id) to a User, or raise 401."""
    user_info = redis_client.get(f"session:{token}")
    if not user_info:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的token")
    user_data = json.loads(user_info)
    user = get_user(db, user_data['username'])
    if user is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在")
    return user


@user_app.post("/register")
def register(user: UserCreate, db: Session = Depends(get_db)):
    """Create a new user; 400 if the username or email is already taken."""
    if get_user(db, username=user.username):
        raise HTTPException(status_code=400, detail="用户名已被注册")
    # The email column is unique; check up front instead of surfacing a DB error as 500.
    if user.email and db.query(User).filter(User.email == user.email).first():
        raise HTTPException(status_code=400, detail="邮箱已被注册")
    new_user = User(
        uuid=str(uuid.uuid4()),
        username=user.username,
        email=user.email,
        hashed_password=get_password_hash(user.password),
    )
    db.add(new_user)
    db.commit()
    db.refresh(new_user)
    return {"message": "用户创建成功"}


@user_app.post("/token")
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate and issue a 7-day Redis-backed session token."""
    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户名或密码不正确",
            headers={"WWW-Authenticate": "Bearer"},
        )

    session_id = str(uuid.uuid4())
    user_info = {
        "id": user.uuid,
        "username": user.username,
        "is_active": user.is_active,
    }
    redis_client.setex(f"session:{session_id}", 604800, json.dumps(user_info))  # expires in 7 days

    user.last_login = datetime.now(timezone.utc)
    db.commit()

    return {"access_token": session_id, "token_type": "bearer"}


@user_app.get("/me")
async def read_users_me(current_user: User = Depends(get_current_user)):
    """Return the current user's profile, never including the password hash."""
    return {
        "uuid": current_user.uuid,
        "username": current_user.username,
        "email": current_user.email,
        "created_at": current_user.created_at,
        "last_login": current_user.last_login,
        "is_active": current_user.is_active,
    }


@user_app.get("/login")
async def login_page():
    """Serve the static login page."""
    return FileResponse("login.html")


@user_app.post("/logout")
async def logout(token: str = Depends(oauth2_scheme)):
    """Invalidate the session token."""
    redis_client.delete(f"session:{token}")
    return {"message": "Successfully logged out"}


@user_app.get("/user-info")
async def get_user_info(current_user: User = Depends(get_current_user)):
    """Return display info; the avatar is the username's upper-cased initial."""
    return {
        "username": current_user.username,
        "email": current_user.email,
        "avatar": current_user.username[:1].upper(),  # slicing tolerates an empty username
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)