122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
from kafka import KafkaConsumer
|
|
import json
|
|
import asyncio
|
|
import redis
|
|
import os
|
|
from dotenv import load_dotenv
|
|
import requests
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
# Load variables from a .env file (if present) into the process environment.
load_dotenv()


def _require_int_env(name):
    """Return environment variable *name* parsed as an int.

    Raises:
        RuntimeError: if the variable is unset or is not a valid integer. The
            message names the variable, unlike the bare ``int(None)``
            TypeError that an unset variable used to produce.
    """
    raw = os.getenv(name)
    if raw is None:
        raise RuntimeError(f"Missing required environment variable {name}")
    try:
        return int(raw)
    except ValueError as err:
        raise RuntimeError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_CHAT_TOPIC = os.getenv('KAFKA_CHAT_TOPIC')
# Overridable from the environment; defaults match the previous hard-coded values.
KAFKA_CONSUMER_GROUP = os.getenv('KAFKA_CONSUMER_GROUP', 'chat_group')
KAFKA_CONSUMER_NUM = int(os.getenv('KAFKA_CONSUMER_NUM', '1'))  # number of consumer threads

# Redis settings
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = _require_int_env('REDIS_PORT')
REDIS_CHAT_DB = _require_int_env('REDIS_CHAT_DB')
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')
REDIS_TASK_DB = _require_int_env('REDIS_TASK_DB')
|
|
|
|
# Redis clients: both share host/port/password and differ only in the logical
# database. redis_client holds per-session chat history; redis_task_client
# holds per-task status and results.
def _redis_client_for(db):
    """Build a Redis client on the configured server for logical database *db*."""
    return redis.Redis(
        host=REDIS_HOST,
        port=REDIS_PORT,
        db=db,
        password=REDIS_PASSWORD,
    )


redis_client = _redis_client_for(REDIS_CHAT_DB)
redis_task_client = _redis_client_for(REDIS_TASK_DB)
|
|
|
|
# System prompt placed at the start of every generation prompt. It asks the
# model to answer briefly, concisely and kindly. It also warns that all input
# comes from speech recognition, so the model should guess past recognition
# errors.
DEFAULT_SYSTEM_PROMPT = (
    "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。"
    "输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"
)
|
|
|
|
# Kafka consumer factory
def create_kafka_consumer():
    """Return a KafkaConsumer subscribed to KAFKA_CHAT_TOPIC.

    The consumer joins KAFKA_CONSUMER_GROUP and uses auto_offset_reset='latest'.
    It auto-commits offsets and decodes every message value from UTF-8 JSON.
    """
    def _decode_json(raw_bytes):
        return json.loads(raw_bytes.decode('utf-8'))

    consumer_options = dict(
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='latest',
        enable_auto_commit=True,
        group_id=KAFKA_CONSUMER_GROUP,
        value_deserializer=_decode_json,
    )
    return KafkaConsumer(KAFKA_CHAT_TOPIC, **consumer_options)
|
|
|
|
# Endpoint of the text-generation service. The request/response shape used
# below (model/prompt/stream in, NDJSON lines carrying "response" out) matches
# the Ollama /api/generate API. Override with CHAT_API_URL if the endpoint moves.
CHAT_API_URL = os.getenv(
    'CHAT_API_URL',
    "https://ffgregevrdcfyhtnhyudvr.myfastools.com/api/generate",
)

# (connect, read) timeouts in seconds for the generation request. With
# streaming, the read timeout bounds the wait between chunks, not the whole
# generation.
CHAT_API_TIMEOUT = (10, 300)


def _build_prompt(history, query):
    """Return the full prompt: system prompt, prior turns, then the new query.

    *history* is the decoded JSON list stored under ``chat:{session_id}``;
    each entry is a two-item ``[query, response]`` pair.
    """
    parts = [DEFAULT_SYSTEM_PROMPT + "\n"]
    parts.extend(
        f"用户: {past_query}\n助手: {past_response}\n"
        for past_query, past_response in history
    )
    parts.append(f"用户: {query}\n助手:")
    return "".join(parts)


def _generate(model, prompt):
    """Stream a completion for *prompt* from CHAT_API_URL and return its text.

    Raises:
        requests.RequestException: on transport errors, timeouts or an HTTP
            error status.
        RuntimeError: if a streamed line carries an ``error`` field.
    """
    data = {
        "model": model,
        "prompt": prompt,
        "stream": True,
        # Ollama takes sampling parameters from "options"; a top-level
        # "temperature" key is not part of its /api/generate request schema.
        "options": {"temperature": 0},
    }
    chunks = []
    # The context manager releases the streamed connection even if parsing fails.
    with requests.post(CHAT_API_URL, json=data, stream=True,
                       timeout=CHAT_API_TIMEOUT) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            json_data = json.loads(line)
            # The service can report a failure inside an otherwise-200 stream;
            # don't store a truncated answer as "completed".
            if 'error' in json_data:
                raise RuntimeError(json_data['error'])
            if 'response' in json_data:
                chunks.append(json_data['response'])
    return "".join(chunks)


async def process_chat_request(chat_request):
    """Handle one chat request consumed from Kafka.

    The steps are:

    1. Mark the task "processing".
    2. Build a prompt from the session history in the chat Redis DB.
    3. Stream a completion from CHAT_API_URL.
    4. Append the new turn to the session history.
    5. Store the response and the result before marking the task "completed".

    On any failure the task is marked "error" with the message, provided the
    task id could be read.

    Args:
        chat_request: decoded message value. It must contain ``task_id``,
            ``session_id`` and ``query``. ``model`` is optional and defaults
            to ``'qwen2.5:3b'``.

    Note: declared ``async`` for its ``asyncio.run`` caller, but every Redis
    and HTTP call inside is blocking.
    """
    # Bound before the try so the error handler never touches an unbound name
    # when the request itself is malformed (e.g. missing 'task_id').
    task_id = None
    try:
        task_id = chat_request['task_id']
        session_id = chat_request['session_id']
        query = chat_request['query']
        model = chat_request.get('model', 'qwen2.5:3b')

        redis_task_client.set(f"chat:{task_id}:status", "processing")

        # Conversation history is keyed by session_id.
        history = json.loads(redis_client.get(f"chat:{session_id}") or '[]')
        text_output = _generate(model, _build_prompt(history, query))

        history.append((query, text_output))
        redis_client.set(f"chat:{session_id}", json.dumps(history))

        # Store the payloads first and flip the status last, so a poller that
        # observes "completed" can always read the response.
        redis_task_client.set(f"chat:{task_id}:response", text_output)
        redis_task_client.set(
            f"chat:{task_id}:result",
            json.dumps({"query": query, "response": text_output}),
        )
        redis_task_client.set(f"chat:{task_id}:status", "completed")

        print(f"处理完成 task_id {task_id}, session_id {session_id}: {text_output}")

    except Exception as e:
        print(f"处理 task {task_id} 时出错: {str(e)}")
        if task_id is None:
            # Malformed request: there is no key to record the failure under.
            return
        try:
            # Same ordering rule as the success path: detail first, status last.
            redis_task_client.set(f"chat:{task_id}:error", str(e))
            redis_task_client.set(f"chat:{task_id}:status", "error")
        except redis.RedisError as redis_err:
            print(f"记录 task {task_id} 错误状态失败: {redis_err}")
|
|
def kafka_consumer_thread(consumer_id):
    """Run one Kafka consumer loop, handling each chat request in turn.

    An exception raised while handling a single message is reported and the
    message is skipped, so one bad request cannot stop the consumer. The
    consumer is closed when the loop exits for any reason.
    """
    consumer = create_kafka_consumer()
    print(f"消费者 {consumer_id} 已启动")
    try:
        for message in consumer:
            try:
                asyncio.run(process_chat_request(message.value))
            except Exception as exc:
                # Defensive: keeps the loop alive if anything escapes the
                # per-request handler.
                print(
                    f"消费者 {consumer_id} 处理消息时出错 "
                    f"(offset {message.offset}): {exc!r}"
                )
    finally:
        consumer.close()
|
|
|
|
def _report_consumer_exit(consumer_id, future):
    """Print why consumer *consumer_id*'s thread stopped, if it failed."""
    if future.cancelled():
        return
    exc = future.exception()
    if exc is not None:
        print(f"消费者 {consumer_id} 异常退出: {exc!r}")


def main():
    """Start KAFKA_CONSUMER_NUM consumer threads and block until they all exit."""
    print("启动Kafka消费者处理聊天请求...")
    with ThreadPoolExecutor(max_workers=KAFKA_CONSUMER_NUM,
                            thread_name_prefix="kafka-consumer") as executor:
        for consumer_id in range(KAFKA_CONSUMER_NUM):
            future = executor.submit(kafka_consumer_thread, consumer_id)
            # Without this, an exception that kills a consumer thread is kept
            # on a Future nobody inspects, and the consumer vanishes silently.
            future.add_done_callback(
                lambda f, cid=consumer_id: _report_consumer_exit(cid, f)
            )


if __name__ == '__main__':
    main()