# Source: api/api_chat/before/mini3_api.py
# Snapshot: 2025-01-12 06:15:15 +00:00 — 182 lines, 6.0 KiB, Python
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
from typing import List, Tuple
from kafka import KafkaConsumer, TopicPartition
from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import redis
import uuid
import logging
import uvicorn
from dotenv import load_dotenv
import os
import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Prefer GPU 0 when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # torch.cuda.set_device() raises on CPU-only machines, so only call it
    # when CUDA actually exists (the original called it unconditionally).
    torch.cuda.set_device(device)
logger.info("Using device: %s", device)

# Load the MiniCPM3-4B model and tokenizer from a local checkout.
# NOTE(review): this path is machine-specific — consider moving it into .env.
path = "/home/zydi/worker_chat/api/OpenBMB/MiniCPM3-4B"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
# The tokenizer ships without a pad token; alias it to EOS so padding-related
# APIs (attention masks, generate) have a defined pad id.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
# Load environment variables from the local .env file.
load_dotenv()

# CORS configuration: comma-separated origin list, e.g. "http://a.com,http://b.com".
# Default to an empty string so a missing ALLOWED_ORIGINS doesn't crash with
# AttributeError on None.split(','); also drop blank/whitespace entries.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv('ALLOWED_ORIGINS', '').split(',')
    if origin.strip()
]

# Register the CORS middleware on the app.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Kafka settings
KAFKA_BROKER = os.getenv('KAFKA_BROKER')
KAFKA_TOPIC = os.getenv('KAFKA_MINI3_TOPIC')
KAFKA_CONSUMER_GROUP = 'mini3_group'
KAFKA_CONSUMER_NUM = 1

# Redis settings — the original called int(os.getenv(...)) with no default,
# which raises TypeError when the variable is unset. Use conventional
# fallbacks (localhost:6379, db 0) so a missing .env fails more gracefully.
REDIS_HOST = os.getenv('REDIS_HOST', 'localhost')
REDIS_PORT = int(os.getenv('REDIS_PORT', '6379'))
REDIS_DB = int(os.getenv('REDIS_MINI3_DB', '0'))
REDIS_PASSWORD = os.getenv('REDIS_PASSWORD')

# Redis client used to persist per-session chat history (keyed by session_id).
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    db=REDIS_DB,
    password=REDIS_PASSWORD,  # authenticate with the configured password
)
def create_kafka_consumer():
    """Build a KafkaConsumer whose message values are JSON-decoded dicts.

    The consumer joins KAFKA_CONSUMER_GROUP, starts from the earliest
    unconsumed offset, and auto-commits offsets. Topic subscription is
    done later by the caller.
    """
    def _decode_json(raw_bytes):
        # Messages are produced as UTF-8 JSON; decode to a Python dict.
        return json.loads(raw_bytes.decode('utf-8'))

    return KafkaConsumer(
        bootstrap_servers=KAFKA_BROKER,
        group_id=KAFKA_CONSUMER_GROUP,
        auto_offset_reset='earliest',
        enable_auto_commit=True,
        value_deserializer=_decode_json,
    )
# Kafka consumer loop
def kafka_consumer(consumer, consumer_id):
    """Consume chat requests from KAFKA_TOPIC and process each one.

    Runs forever inside a daemon thread. A failure while handling one
    message is logged and must not kill the loop (the original let any
    exception terminate the thread silently).
    """
    consumer.subscribe([KAFKA_TOPIC])
    # NOTE(review): assignment() is typically empty until the first poll,
    # so this log line may show no partitions right after subscribing.
    partitions = consumer.assignment()
    logger.info(f"消费者 {consumer_id} 被分配了以下分区: {[p.partition for p in partitions]}")
    for message in consumer:
        # message.value is already a dict thanks to the JSON value_deserializer.
        chat_request = message.value
        logger.info(f"消费者 {consumer_id} 正在处理来自分区 {message.partition} 的消息:")
        try:
            # The handler is async; give each message a short-lived event loop.
            asyncio.run(process_chat_request(chat_request))
        except Exception:
            # Log with traceback and keep consuming subsequent messages.
            logger.exception(
                "Consumer %s failed on partition=%s offset=%s",
                consumer_id, message.partition, message.offset,
            )
# Launch the Kafka consumer threads
def start_kafka_consumers(num_consumers=KAFKA_CONSUMER_NUM):
    """Spawn `num_consumers` daemon threads, each running a consumer loop.

    Returns a list of (consumer, thread) pairs so callers can inspect or
    keep references to what was started.
    """
    launched = []
    for worker_id in range(num_consumers):
        kafka_client = create_kafka_consumer()
        worker = threading.Thread(
            target=kafka_consumer,
            args=(kafka_client, worker_id),
            daemon=True,
        )
        worker.start()
        launched.append((kafka_client, worker))
    return launched
# System prompt prepended to every conversation. (In Chinese: instructs the
# assistant to answer briefly and in a friendly way, and to guess the user's
# intent, because all input comes from speech recognition and may contain
# transcription errors.)
DEFAULT_SYSTEM_PROMPT = "你是一个智能助手,请用尽可能简短、简洁、友好的方式回答问题。输入的所有内容都来自于语音识别输入,因此可能会出现各种错误,请尽可能猜测用户的意思"


class ChatRequest(BaseModel):
    """Request body for the /mini3 endpoint (also built from Kafka messages)."""
    session_id: str  # Redis key under which this conversation's history lives
    query: str  # user utterance (speech-recognition output)
    model: str = "minicpm3-4b"  # accepted but not used for routing in this file


class ChatResponse(BaseModel):
    """Response body for the /mini3 endpoint."""
    response: str  # model's reply to the latest query
    history: List[Tuple[str, str]]  # all (query, response) pairs so far
# Async handler for chat requests arriving via Kafka
async def process_chat_request(chat_request):
    """Run one Kafka-delivered chat request through the /mini3 handler.

    `chat_request` is the JSON dict from Kafka; it must contain the fields
    of ChatRequest (at least `session_id` and `query`).
    """
    # Read the id once with .get: the original indexed chat_request['session_id']
    # inside the except path too, where a missing key would raise a second error.
    session_id = chat_request.get('session_id')
    try:
        response = await chat(ChatRequest(**chat_request))
        # Use the module logger (not print) for consistency with the rest of
        # the file; lazy %-formatting avoids building strings when disabled.
        logger.info("Processed message for session %s: %s", session_id, response)
    except Exception:
        # logger.exception records the traceback, unlike the original print().
        logger.exception("Error processing message for session %s", session_id)
@app.post("/mini3", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a reply with MiniCPM3-4B, keeping per-session history in Redis.

    Raises HTTPException(500) if tokenization or generation fails.
    """
    session_id = request.session_id
    query = request.query
    # Fetch prior (query, response) pairs for this session; default to empty list.
    history = json.loads(redis_client.get(session_id) or '[]')
    # Rebuild the full transcript: system prompt, then alternating user/assistant
    # turns, ending with the new user query.
    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
    for past_query, past_response in history:
        messages.append({"role": "user", "content": past_query})
        messages.append({"role": "assistant", "content": past_response})
    messages.append({"role": "user", "content": query})
    try:
        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        # Attention mask: 1 wherever the token is not the pad token.
        # NOTE(review): pad_token was aliased to eos_token at startup, so any
        # genuine EOS tokens inside the prompt would be masked out here —
        # confirm this is intended for this tokenizer's chat template.
        attention_mask = model_inputs.ne(tokenizer.pad_token_id).long()
        # Move inputs to the configured device (CPU or GPU).
        model_inputs = model_inputs.to(device)
        attention_mask = attention_mask.to(device)
        model_outputs = model.generate(
            model_inputs,
            attention_mask=attention_mask,
            max_new_tokens=1024,
            top_p=0.7,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,  # set pad_token_id to eos_token_id
            do_sample=True
        )
        # Strip the prompt tokens; decode only the newly generated tail.
        output_token_ids = model_outputs[0][len(model_inputs[0]):]
        text_output = tokenizer.decode(output_token_ids, skip_special_tokens=True)
        # Persist the updated history. NOTE(review): the key has no TTL and the
        # list grows without bound — consider capping length or setting expiry.
        history.append((query, text_output))
        redis_client.set(session_id, json.dumps(history))
        return ChatResponse(response=text_output, history=history)
    except Exception as e:
        # Surface any tokenizer/model/Redis failure as an HTTP 500 to the caller.
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/start_chat")
async def start_chat():
    """Create a fresh chat session and return its identifier.

    The id is a random UUID4; the client passes it back as `session_id`
    on subsequent /mini3 calls so history accumulates in Redis.
    """
    new_session = uuid.uuid4()
    return {"session_id": str(new_session)}
def _main():
    """Start the background Kafka consumer threads, then serve the API."""
    start_kafka_consumers()
    # Blocking call: run the FastAPI app with uvicorn.
    uvicorn.run(app, host="0.0.0.0", port=6003)


if __name__ == '__main__':
    _main()