from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import json
from typing import List, Tuple
from kafka import KafkaConsumer
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)
logger.info(f"Using device: {device}")

# Load the MiniCPM3-4B model and tokenizer from a local checkpoint
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)

# Load environment variables from the .env file
load_dotenv()

# CORS configuration
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS').split(',')

# Add the CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
KAFKA_CONSUMER_GROUP = 'mini3_group'
KAFKA_CONSUMER_NUM = 1

# Redis settings
REDIS_HOST = os.getenv('REDIS_HOST')
REDIS_PORT = int(os.getenv('REDIS_PORT'))
REDIS_DB = int(os.getenv('REDIS_MINI3_DB'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

# Create the Redis client
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD  # authenticate with the configured password
)

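# Conversation history is stored in Redis keyed by session_id, as a JSON-encoded
# list of [query, response] pairs (see the /mini3 handler below).
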
# Create a Kafka consumer
def create_kafka_consumer():
    return KafkaConsumer(
        bootstrap_servers=KAFKA_BROKER,
        auto_offset_reset='earliest',
        enable_auto_commit=True,
        group_id=KAFKA_CONSUMER_GROUP,
        value_deserializer=lambda x: json.loads(x.decode('utf-8'))
    )


# Kafka consumer loop: pull chat requests off the topic and process them
def kafka_consumer(consumer, consumer_id):
    # Subscribe to the topic; note that partition assignment only happens after
    # the first poll, so this set may still be empty at this point
    consumer.subscribe([KAFKA_TOPIC])
    partitions = consumer.assignment()

    logger.info(f"Consumer {consumer_id} was assigned partitions: {[p.partition for p in partitions]}")

    for message in consumer:
        partition = message.partition
        offset = message.offset
        chat_request = message.value  # message.value is already deserialized into a dict
        session_id = chat_request['session_id']

        logger.info(f"Consumer {consumer_id} processing message from partition {partition}, offset {offset}, session {session_id}")

        asyncio.run(process_chat_request(chat_request))


# Start the Kafka consumer threads
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
    consumers = []
    for i in range(num_consumers):
        consumer = create_kafka_consumer()
        consumer_thread = threading.Thread(target=kafka_consumer, args=(consumer, i), daemon=True)
        consumer_thread.start()
        consumers.append((consumer, consumer_thread))
    return consumers

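# Illustrative sketch (not wired into this service): how a producer could publish
# a chat request onto KAFKA_TOPIC in the {"session_id", "query"} shape that
# kafka_consumer() above expects. The function name is hypothetical; it assumes
# the kafka-python KafkaProducer API.
def publish_chat_request_example(session_id: str, query: str):
    from kafka import KafkaProducer
    producer = KafkaProducer(
        bootstrap_servers=KAFKA_BROKER,
        value_serializer=lambda v: json.dumps(v).encode('utf-8')
    )
    # Messages are JSON objects matching what the consumer deserializes above
    producer.send(KAFKA_TOPIC, {"session_id": session_id, "query": query})
    producer.flush()
    producer.close()
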
# System prompt (in Chinese): answer as briefly and friendly as possible, and
# assume the input comes from speech recognition, so guess the user's intent
# despite transcription errors.
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"


class ChatRequest(BaseModel):
    session_id: str
    query: str
    model: str = "minicpm3-4b"


class ChatResponse(BaseModel):
    response: str
    history: List[Tuple[str, str]]


# Async helper that handles a chat request coming from Kafka
async def process_chat_request(chat_request):
    try:
        response = await chat(ChatRequest(**chat_request))
        logger.info(f"Processed message for session {chat_request['session_id']}: {response}")
    except Exception as e:
        logger.error(f"Error processing message for session {chat_request['session_id']}: {str(e)}")


@app.post("/mini3", response_model=ChatResponse)
async def chat(request: ChatRequest):
    session_id = request.session_id
    query = request.query

    # Fetch the conversation history from Redis
    history = json.loads(redis_client.get(session_id) or '[]')

    # Build the full prompt, including past turns
    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    for past_query, past_response in history:
        messages.append({"role": "user", "content": past_query})
        messages.append({"role": "assistant", "content": past_response})
    messages.append({"role": "user", "content": query})

    try:
        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)

        # Build the attention mask (pad-token positions are masked out)
        attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()

        # Move the inputs to the right device (CPU or GPU)
        model_inputs = model_inputs.to(device)
        attention_mask = attention_mask.to(device)

        model_outputs = model.generate(
            model_inputs,
            attention_mask=attention_mask,
            max_new_tokens=1024,
            top_p=0.7,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # use eos_token_id as pad_token_id
            do_sample=True
        )

        # Strip the prompt tokens and decode only the newly generated text
        output_token_ids = model_outputs[0][len(model_inputs[0]):]
        text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)

        # Update the conversation history
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))

        return ChatResponse(response=text_output, history=history)

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Create a new chat session and return its id
@app.post("/start_chat")
async def start_chat():
    session_id = str(uuid.uuid4())
    return {"session_id": session_id}


if __name__ == '__main__':
    # Start the Kafka consumer threads
    start_kafka_consumers()

    # Start the FastAPI server
    uvicorn.run(app, host="0.0.0.0", port=6003)
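
# Example client usage (illustrative sketch; assumes the service is reachable at
# http://localhost:6003 and that the `requests` package is installed):
#
#   import requests
#   session = requests.post("http://localhost:6003/start_chat").json()
#   reply = requests.post(
#       "http://localhost:6003/mini3",
#       json={"session_id": session["session_id"], "query": "你好"},
#   ).json()
#   print(reply["response"])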