api/api_chat/ollamas.py

import json
import logging
import uuid
from typing import Dict, Optional

import httpx
import ollama
import redis
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection; request/response pairs are stored under "request:<uuid>" keys.
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GenerateRequest(BaseModel):
    model: Optional[str] = "qwen2.5:3b"
    prompt: str


class RawGenerateRequest(BaseModel):
    model: Optional[str] = "qwen2.5:3b"
    prompt: str
    system_prompt: Optional[str] = None
    stream: Optional[bool] = False
    raw: Optional[bool] = False
    format: Optional[str] = None
    options: Optional[Dict] = None


class GenerateResponse(BaseModel):
    response: dict
    request_id: str
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
logger.info(f"收到请求: {request}")
request_id = str(uuid.uuid4())
try:
response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
full_response = response['message']['content']
request_data = {
"model": request.model,
"prompt": request.prompt,
"response": full_response
}
redis_client.set(f"request:{request_id}", json.dumps(request_data))
response_data = {
"response": full_response,
"model": request.model
}
return GenerateResponse(response=response_data, request_id=request_id)
except Exception as e:
logger.error(f"发生错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
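

# Illustrative client sketch (not part of the original service): shows how the
# /generate endpoint above can be called with httpx, assuming the uvicorn settings
# at the bottom of this file (port 7000 on the local host). The function name and
# the sample prompt are placeholders added for documentation purposes.
async def _example_generate_call() -> dict:
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(
            "http://localhost:7000/generate",
            json={"model": "qwen2.5:3b", "prompt": "Hello"},
        )
        resp.raise_for_status()
        # Expected shape: {"response": {"response": "...", "model": "..."}, "request_id": "..."}
        return resp.json()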
@app.post("/api/generate")
async def generate_without_history(request: RawGenerateRequest):
"""
处理无历史记录的生成请求。
参数:
- request: RawGenerateRequest对象,包含生成请求的所有参数。
返回:
- 包含生成结果的字典。
"""
try:
response = ollama.generate(
model=request.model,
prompt=request.prompt,
system=request.system_prompt,
format=request.format,
options=request.options,
stream=request.stream
)
response_data = {
"model": request.model,
"response": response['response'],
"done": True,
"context": response.get('context'),
"total_duration": response.get('total_duration'),
"load_duration": response.get('load_duration'),
"prompt_eval_count": response.get('prompt_eval_count'),
"prompt_eval_duration": response.get('prompt_eval_duration'),
"eval_count": response.get('eval_count'),
"eval_duration": response.get('eval_duration')
}
request_id = str(uuid.uuid4())
redis_client.set(f"request:{request_id}", json.dumps(response_data))
return response_data
except Exception as e:
logger.error(f"发生未预期的错误: {e}")
logger.exception("详细错误信息:")
raise HTTPException(status_code=500, detail=f"处理Ollama请求时发生错误: {str(e)}")
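

# Illustrative client sketch (not part of the original service): a request against
# /api/generate with a system prompt, JSON output format, and sampling options.
# Field names mirror RawGenerateRequest; the concrete values are placeholders.
async def _example_raw_generate_call() -> dict:
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(
            "http://localhost:7000/api/generate",
            json={
                "model": "qwen2.5:3b",
                "prompt": "Describe a cat as a JSON object.",
                "system_prompt": "You are a helpful assistant.",
                "format": "json",
                "options": {"temperature": 0.2},
            },
        )
        resp.raise_for_status()
        return resp.json()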
@app.get("/request/{request_id}", response_model=Dict)
async def get_request(request_id: str):
request_data = redis_client.get(f"request:{request_id}")
if request_data:
return json.loads(request_data)
raise HTTPException(status_code=404, detail="请求未找到")
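

# Illustrative client sketch (not part of the original service): retrieving a stored
# request/response pair by the request_id returned from /generate or /api/generate.
async def _example_get_request(request_id: str) -> dict:
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"http://localhost:7000/request/{request_id}")
        resp.raise_for_status()
        return resp.json()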
@app.get("/models")
async def list_models():
return ollama.list()
@app.get("/models/{model_name}")
async def show_model(model_name: str):
return ollama.show(model_name)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7000)