Files
api/api_chat/chat.py
T
2025-01-12 06:15:15 +00:00

122 lines
4.1 KiB
Python

from kafka import KafkaConsumer
import json
import asyncio
import redis
import os
from dotenv import load_dotenv
import requests
from concurrent.futures import ThreadPoolExecutor
# Load environment variables from a .env file, if one is present.
load_dotenv()


def _require_int_env(name):
    """Return the environment variable *name* parsed as an int.

    Raises:
        RuntimeError: if the variable is unset or is not a valid integer.
            The message names the variable, which a bare ``int(None)``
            TypeError would not.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
KAFKA_CONSUMER_GROUP = 'chat_group'
KAFKA_CONSUMER_NUM = 1  # number of consumer threads started by main()

# Redis settings
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _require_int_env('REDIS_PORT')
REDIS_CHAT_DB = _require_int_env('REDIS_CHAT_DB')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = _require_int_env('REDIS_TASK_DB')
# Redis clients. redis_client holds per-session chat history
# ("chat:{session_id}"); redis_task_client holds per-task records
# ("chat:{task_id}:status" / ":response" / ":result" / ":error").
# Both connect to the same server and differ only in database number.
def _make_redis_client(db):
    """Build a Redis client for database *db* on the configured server."""
    return redis.Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        db=db,
        password=REDIS_PASSWORD,
    )


redis_client = _make_redis_client(REDIS_CHAT_DB)
redis_task_client = _make_redis_client(REDIS_TASK_DB)
# System prompt placed at the start of every prompt built by
# process_chat_request. English translation: "You are an intelligent
# assistant; answer as briefly, concisely and kindly as possible. All input
# comes from speech recognition and may therefore contain all kinds of errors,
# so try your best to guess what the user means."
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
def _decode_json_message(raw):
    """Decode a raw Kafka message value (UTF-8 encoded JSON bytes)."""
    return json.loads(raw.decode('utf-8'))


def create_kafka_consumer():
    """Create a Kafka consumer subscribed to the chat request topic.

    The consumer joins ``KAFKA_CONSUMER_GROUP`` with auto-commit enabled,
    uses ``auto_offset_reset='latest'``, and decodes each message value
    from UTF-8 JSON.
    """
    settings = {
        'bootstrap_servers': KAFKA_BROKER,
        'auto_offset_reset': 'latest',
        'enable_auto_commit': True,
        'group_id': KAFKA_CONSUMER_GROUP,
        'value_deserializer': _decode_json_message,
    }
    return KafkaConsumer(KAFKA_CHAT_TOPIC, **settings)
async def process_chat_request(chat_request):
    """Answer one chat request via the LLM backend and record the outcome in Redis.

    Args:
        chat_request: decoded Kafka message. Must contain ``task_id``,
            ``session_id`` and ``query``; ``model`` is optional and defaults
            to ``'qwen2.5:3b'``.

    On success:
        * task DB ``chat:{task_id}:status`` goes "processing" -> "completed";
        * task DB ``chat:{task_id}:response`` and ``chat:{task_id}:result``
          store the reply;
        * chat DB ``chat:{session_id}`` gains the new (query, reply) turn.

    On any exception, the error is printed. If a ``task_id`` was read, the
    status is set to "error" and the message is stored under
    ``chat:{task_id}:error``. Exceptions are not propagated to the caller,
    except any raised by those error-recording Redis writes themselves.
    """
    # Bound before ``try`` so the error handler can always reference it, even
    # when the request has no 'task_id' (or is not a mapping at all).
    task_id = None
    try:
        task_id = chat_request['task_id']
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        # Mark the task as in progress.
        redis_task_client.set(f"chat:{task_id}:status", "processing")

        # Session history: a JSON list of (query, response) pairs.
        # NOTE(review): history is never truncated, so the prompt grows with
        # every turn of a session.
        history = json.loads(redis_client.get(f"chat:{session_id}") or '[]')

        # Build a single completion prompt: system prompt, prior turns, new query.
        prompt_parts = [DEFAULT_SYSTEM_PROMPT + "\n"]
        for past_query, past_response in history:
            prompt_parts.append(f"用户: {past_query}\n助手: {past_response}\n")
        prompt_parts.append(f"用户: {query}\n助手:")
        full_prompt = "".join(prompt_parts)

        data = {
            "model": model,
            "prompt": full_prompt,
            "stream": True,
            # The Ollama /api/generate endpoint reads sampling parameters from
            # "options"; a top-level "temperature" key is ignored.
            "options": {"temperature": 0},
        }

        # (connect, read) timeouts in seconds, so a stalled backend cannot block
        # this consumer forever. The read timeout bounds the wait between
        # received chunks, not the total generation time. The ``with`` block
        # releases the streamed connection even if parsing fails.
        with requests.post(
            "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate",
            json=data,
            stream=True,
            timeout=(10, 300),
        ) as response:
            response.raise_for_status()
            chunks = []
            # The backend streams newline-delimited JSON objects; each carries a
            # fragment of the reply under 'response'.
            for line in response.iter_lines():
                if not line:
                    continue
                json_data = json.loads(line)
                if 'error' in json_data:
                    # A streamed error object (how Ollama reports mid-stream
                    # failures) would otherwise be saved as a "completed" reply.
                    raise RuntimeError(json_data['error'])
                if 'response' in json_data:
                    chunks.append(json_data['response'])
        text_output = "".join(chunks)

        # Append this turn to the session history.
        history.append((query, text_output))
        redis_client.set(f"chat:{session_id}", json.dumps(history))

        # Publish the task outcome.
        redis_task_client.set(f"chat:{task_id}:status", "completed")
        redis_task_client.set(f"chat:{task_id}:response", text_output)
        redis_task_client.set(
            f"chat:{task_id}:result",
            json.dumps({"query": query, "response": text_output}),
        )
        print(f"处理完成 task_id {task_id}, session_id {session_id}: {text_output}")
    except Exception as e:
        print(f"处理 task {task_id} 时出错: {str(e)}")
        # Without a task_id there is no task record to mark as failed.
        if task_id is not None:
            redis_task_client.set(f"chat:{task_id}:status", "error")
            redis_task_client.set(f"chat:{task_id}:error", str(e))
def kafka_consumer_thread(consumer_id):
    """Consume chat requests from Kafka and process them one at a time.

    If handling a single message fails, the error is printed and the message
    is skipped, so one bad request cannot stop the consumer. If iterating the
    consumer itself raises, the error is printed and re-raised. In both cases
    the consumer is closed when this function exits.

    Args:
        consumer_id: index used only to label log output.
    """
    consumer = create_kafka_consumer()
    print(f"消费者 {consumer_id} 已启动")
    try:
        for message in consumer:
            try:
                # process_chat_request is a coroutine; run it to completion on
                # a fresh event loop for this message.
                asyncio.run(process_chat_request(message.value))
            except Exception as e:
                print(f"消费者 {consumer_id} 处理消息时出错: {e!r}")
    except Exception as e:
        print(f"消费者 {consumer_id} 异常退出: {e!r}")
        raise
    finally:
        consumer.close()
def main():
    """Start KAFKA_CONSUMER_NUM consumer threads and block until all of them exit.

    ThreadPoolExecutor stores a worker's exception in its Future rather than
    raising it. Each Future is therefore inspected as soon as its thread
    finishes, so a dead consumer is reported instead of vanishing silently.
    """
    from concurrent.futures import as_completed

    print("启动Kafka消费者处理聊天请求...")
    with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM) as executor:
        futures = {
            executor.submit(kafka_consumer_thread, consumer_id): consumer_id
            for consumer_id in range(KAFKA_CONSUMER_NUM)
        }
        for future in as_completed(futures):
            error = future.exception()
            if error is not None:
                print(f"消费者线程 {futures[future]} 已终止: {error!r}")
# Script entry point: start the Kafka consumer thread(s) for chat requests.
if __name__ == '__main__':
    main()