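"""FastAPI service that proxies requests to Ollama.

Exposes chat and raw generation endpoints, stores each request/response pair in
Redis under a UUID, and provides endpoints to retrieve stored requests and to
list or inspect the available models.
"""
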
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import httpx
import json
import redis
from typing import List, Dict, Optional
import logging
import ollama
import uuid

app = FastAPI()
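
# Open CORS policy: allow requests from any origin with any method and headers.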
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis connection used to persist request/response records.
redis_client = redis.Redis(host='222.186.10.253', port=6379, db=14, password="Obscura@2024")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Request/response schemas for the generation endpoints.
class GenerateRequest(BaseModel):
    model: Optional[str] = "qwen2.5:3b"
    prompt: str


class RawGenerateRequest(BaseModel):
    model: Optional[str] = "qwen2.5:3b"
    prompt: str
    system_prompt: Optional[str] = None
    stream: Optional[bool] = False
    raw: Optional[bool] = False
    format: Optional[str] = None
    options: Optional[Dict] = None


class GenerateResponse(BaseModel):
    response: dict
    request_id: str


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Chat-style generation; the request and response are stored in Redis."""
    logger.info(f"Received request: {request}")

    request_id = str(uuid.uuid4())

    try:
        # Run the prompt through Ollama's chat API as a single user message.
        response = ollama.chat(model=request.model, messages=[{"role": "user", "content": request.prompt}])
        full_response = response['message']['content']

        request_data = {
            "model": request.model,
            "prompt": request.prompt,
            "response": full_response
        }

        # Persist the request/response pair under the generated request ID.
        redis_client.set(f"request:{request_id}", json.dumps(request_data))

        response_data = {
            "response": full_response,
            "model": request.model
        }

        return GenerateResponse(response=response_data, request_id=request_id)

    except Exception as e:
        logger.error(f"Error while handling request: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/generate")
async def generate_without_history(request: RawGenerateRequest):
    """
    Handle a generation request without conversation history.

    Args:
        request: RawGenerateRequest containing all parameters for the generation call.

    Returns:
        A dict with the generated text and Ollama's timing metadata.
    """
    try:
        # Forward the request to Ollama's generate API.
        # Note: stream=True is accepted here, but the handling below assumes a
        # single non-streaming response.
        response = ollama.generate(
            model=request.model,
            prompt=request.prompt,
            system=request.system_prompt,
            format=request.format,
            options=request.options,
            raw=request.raw,
            stream=request.stream
        )

        response_data = {
            "model": request.model,
            "response": response['response'],
            "done": True,
            "context": response.get('context'),
            "total_duration": response.get('total_duration'),
            "load_duration": response.get('load_duration'),
            "prompt_eval_count": response.get('prompt_eval_count'),
            "prompt_eval_duration": response.get('prompt_eval_duration'),
            "eval_count": response.get('eval_count'),
            "eval_duration": response.get('eval_duration')
        }

        # Store the full result in Redis under a fresh request ID.
        request_id = str(uuid.uuid4())
        redis_client.set(f"request:{request_id}", json.dumps(response_data))

        return response_data

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        logger.exception("Detailed error information:")
        raise HTTPException(status_code=500, detail=f"Error while handling the Ollama request: {str(e)}")


@app.get("/request/{request_id}", response_model=Dict)
async def get_request(request_id: str):
    """Fetch a stored request/response record from Redis by its request ID."""
    request_data = redis_client.get(f"request:{request_id}")
    if request_data:
        return json.loads(request_data)
    raise HTTPException(status_code=404, detail="Request not found")


@app.get("/models")
async def list_models():
    """List the models available on the local Ollama instance."""
    return ollama.list()


@app.get("/models/{model_name}")
async def show_model(model_name: str):
    """Show details for a single Ollama model."""
    return ollama.show(model_name)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7000)
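
# Example usage (a sketch, assuming the server is reachable on localhost:7000 as
# configured above and that the qwen2.5:3b model has been pulled into Ollama):
#
#   curl -X POST http://localhost:7000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "qwen2.5:3b", "prompt": "Hello"}'
#
#   # Then fetch the stored record using the returned request_id:
#   curl http://localhost:7000/request/<request_id>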