356 lines
13 KiB
Python
356 lines
13 KiB
Python
import os
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from kafka import KafkaProducer, KafkaConsumer
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
|
from qwen_vl_utils import process_vision_info
|
|
from decord import VideoReader, cpu
|
|
from PIL import Image
|
|
from redis import Redis
|
|
import io
|
|
import re
|
|
import torch
|
|
from contextlib import asynccontextmanager
|
|
import threading
|
|
|
|
# Root ASGI application. The Qwen-VL endpoints are registered on the
# sub-application ``qwenvl_app``, which is served under the "/qwenvl" prefix.
app = FastAPI()
qwenvl_app = FastAPI()
app.mount("/qwenvl", qwenvl_app)

# Make GPU index 1 this process's default CUDA device (the same index used in
# the model's device_map).
torch.cuda.set_device(1)

# CORS settings: only the production front-end origin may call the API; it
# may send credentials and use any method or header.
ALLOWED_ORIGINS = ['https://beta.obscura.work']

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=ALLOWED_ORIGINS,
)
|
|
|
|
# Configuration. Each value can be overridden with a QWENVL_-prefixed
# environment variable; the defaults are the previously hard-coded values, so
# behaviour is unchanged when no variable is set.
MODEL_PATH = os.environ.get("QWENVL_MODEL_PATH", "/home/zydi/models/qwen/Qwen2-VL-2B-Instruct")
KAFKA_BROKER = os.environ.get("QWENVL_KAFKA_BROKER", "222.186.10.253:9092")
KAFKA_TOPIC = os.environ.get("QWENVL_KAFKA_TOPIC", "qwenvl")
KAFKA_GROUP_ID = os.environ.get("QWENVL_KAFKA_GROUP_ID", "qwenvl_group")

REDIS_HOST = os.environ.get("QWENVL_REDIS_HOST", "222.186.10.253")
REDIS_PORT = int(os.environ.get("QWENVL_REDIS_PORT", "6379"))
# SECURITY(review): a Redis credential is committed in source as the fallback
# below. Supply QWENVL_REDIS_PASSWORD from the environment/secret store,
# rotate the committed value, then remove this default.
REDIS_PASSWORD = os.environ.get("QWENVL_REDIS_PASSWORD", "Obscura@2024")
REDIS_DB = int(os.environ.get("QWENVL_REDIS_DB", "8"))

# Where uploads are stored and where per-file JSON results are written.
UPLOAD_DIR = os.environ.get("QWENVL_UPLOAD_DIR", "/www/wwwroot/beta.obscura.work/upload_files/upload")
RESULT_DIR = os.environ.get("QWENVL_RESULT_DIR", "/www/wwwroot/beta.obscura.work/upload_files/result")
# NOTE(review): MAX_FILE_AGE is not referenced in the visible code; nothing
# here deletes old uploads/results. Confirm whether a cleanup job uses it.
MAX_FILE_AGE = timedelta(hours=1)
|
|
|
|
# Create the upload and result directories up front (no-op if they exist).
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

# Kafka: the producer enqueues upload tasks; the consumer (shared consumer
# group, offsets auto-committed, starting from the earliest offset when the
# group has none) feeds the background worker. Message values are decoded
# from UTF-8 JSON.
producer = KafkaProducer(bootstrap_servers=[KAFKA_BROKER])
consumer = KafkaConsumer(
    KAFKA_TOPIC,
    group_id=KAFKA_GROUP_ID,
    bootstrap_servers=[KAFKA_BROKER],
    enable_auto_commit=True,
    auto_offset_reset='earliest',
    value_deserializer=lambda raw_bytes: json.loads(raw_bytes.decode('utf-8')),
)

# Redis holds the per-file task status and results.
redis_client = Redis(
    db=REDIS_DB,
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
)
|
|
|
|
|
|
# Load the Qwen2-VL weights onto GPU 1; torch_dtype="auto" lets transformers
# choose the dtype from the checkpoint.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype="auto",
    device_map="cuda:1",
)

# Pixel budget per image/frame, expressed in multiples of 28*28 pixels
# (presumably the processor's per-visual-token patch area — see Qwen2-VL docs).
min_pixels = 128 * 28 * 28
max_pixels = 512 * 28 * 28
processor = AutoProcessor.from_pretrained(
    MODEL_PATH,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)
|
|
|
|
class MediaAnalysisSystem:
    """Describe images and videos with a Qwen2-VL model and extract key facts.

    ``process_image`` / ``process_video`` ask the model (with a Chinese prompt)
    for a detailed description of the media, then run simple keyword and regex
    matching over the answer (``extract_info``) to pull out the environment, a
    head count, actions, interactions, objects and furniture.
    """

    # Keyword tables used by extract_info. Each is matched by plain substring
    # search against the model's answer, in list order.
    _ENVIRONMENTS = ["办公室", "室内", "室外", "会议室"]
    _ACTIONS = ["坐", "站", "摔倒", "跳舞", "转身", "摔", "倒", "倒下", "躺下", "跳跃", "跳", "躺", "睡", "说话"]
    _INTERACTIONS = ["互动", "交流", "身体语言", "交谈", "讨论", "开会"]
    _OBJECTS = ["水瓶", "办公用品", "文件", "电脑"]
    _FURNITURE = ["椅子", "桌子", "咖啡桌", "文件柜", "床", "沙发"]

    # Head-count patterns, tried in order; the first pattern that matches
    # anywhere in the answer decides num_people from its first capture group.
    # Compiled once here instead of on every call.
    _PEOPLE_PATTERNS = [
        re.compile(pattern)
        for pattern in (
            r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一|两|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
            r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
            r'几\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
            r'(男|女)(性|生|士)',
            r'(成年|未成年|青少年|老年)\s*(人|群体)',
            r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
            r'(群众|民众|大众|公众)',
            r'(男女|老少|老幼|大人|小孩)',
        )
    ]

    # Captured tokens that give a definite head count. "两" is the usual
    # spoken form of two ("两个人") and was previously not recognised.
    _CHINESE_COUNTS = {
        '一个': 1, '一': 1, '两': 2, '二': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '八': 8, '九': 9, '十': 10,
    }

    def __init__(self, model, processor):
        """Store the generation model and its processor.

        Args:
            model: A Qwen2-VL generation model (must provide ``generate`` and
                ``device``).
            processor: The matching processor (chat template, tokenisation,
                decoding).
        """
        self.model = model
        self.processor = processor
        # Upper bound on frames fed to the model per video.
        self.MAX_NUM_FRAMES = 10

    @staticmethod
    def _uniform_sample(items, n):
        """Return ``n`` evenly spaced elements of ``items`` (centre of n equal bins)."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode a video from bytes and sample up to MAX_NUM_FRAMES frames.

        Frames are first taken every ``round(avg_fps)`` frames (about one per
        second); if that yields more than MAX_NUM_FRAMES, they are thinned
        uniformly.

        Args:
            video_data: Raw bytes of the video file.

        Returns:
            list of PIL images.

        Raises:
            ValueError: if the video has no frames to sample.
        """
        vr = VideoReader(io.BytesIO(video_data))
        # Clamp the step to 1: for a reported average FPS below 0.5, round()
        # gives 0 and range() would raise "arg 3 must not be zero".
        step = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), step))
        if not frame_idx:
            raise ValueError("Video contains no frames")
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_media(self, media_data, object_name, media_type='image'):
        """Run the model on one image or video and return its description.

        Args:
            media_data: Raw file bytes.
            object_name: Name used only in log and error messages.
            media_type: 'video' to sample frames with ``encode_video``; any
                other value is treated as an image.

        Returns:
            dict with ``original_answer`` (model text), ``extracted_info`` (see
            ``extract_info``), ``timestamp`` ("%Y-%m-%d %H:%M:%S", local time)
            and, for videos, ``num_frames``.

        Raises:
            ValueError: if ``media_data`` is empty or the video has no frames.
        """
        if not media_data:
            raise ValueError(f"Empty {media_type} data for {object_name}")

        print(f"Processing {media_type}: {object_name}, data size: {len(media_data)} bytes")

        frames = None
        if media_type == 'video':
            frames = self.encode_video(media_data)
            media_content = {"type": "video", "video": frames, "fps": 1.0}
        else:  # image
            image = Image.open(io.BytesIO(media_data))
            media_content = {"type": "image", "image": image}

        messages = [
            {
                "role": "user",
                "content": [
                    media_content,
                    {"type": "text", "text": "用中文尽可能详细地描述这个" + ("视频" if media_type == "video" else "图片") + ",包括场景、人物数量、行为变化等。"},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the model's own device instead of a hard-coded device string.
        inputs = inputs.to(self.model.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        answer = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        result = {
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
        if frames is not None:
            result["num_frames"] = len(frames)
        return result

    def process_video(self, video_data, object_name):
        """Describe a video; see ``process_media``."""
        return self.process_media(video_data, object_name, media_type='video')

    def process_image(self, image_data, object_name):
        """Describe an image; see ``process_media``."""
        return self.process_media(image_data, object_name, media_type='image')

    @staticmethod
    def extract_time_from_filename(object_name):
        """Parse a ``YYYYmmdd_HHMMSS`` prefix of the file name into a 10 s window.

        Args:
            object_name: A path or file name such as ``20240102_030405.mp4``.

        Returns:
            (start_time, end_time) naive datetimes, end = start + 10 seconds.
            If the name does not have that shape, start is the current local
            time (a message is printed).
        """
        filename = os.path.basename(object_name)
        try:
            # IndexError (no "_") is handled here too; it used to escape.
            parts = filename.split('_')
            time_str = parts[0] + '_' + parts[1].split('.')[0]
            start_time = datetime.strptime(time_str, "%Y%m%d_%H%M%S")
        except (IndexError, ValueError):
            print(f"无法从文件名 '{filename}' 解析时间。使用默认时间。")
            start_time = datetime.now()
        # Derive end from a single start so the window is exactly 10 s.
        return start_time, start_time + timedelta(seconds=10)

    @staticmethod
    def extract_info(answer):
        """Pull simple structured facts out of the model's free-text answer.

        Args:
            answer: The model's description text.

        Returns:
            dict with ``environment`` (first matching keyword or None),
            ``num_people`` (int or None), and lists ``actions``,
            ``interactions``, ``objects``, ``furniture`` holding every listed
            keyword that occurs in the answer, in table order.
        """
        cls = MediaAnalysisSystem
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "interactions": [],
            "objects": [],
            "furniture": [],
        }

        info["environment"] = next((env for env in cls._ENVIRONMENTS if env in answer), None)

        for pattern in cls._PEOPLE_PATTERNS:
            match = pattern.search(answer)
            if match:
                token = match.group(1)
                # NOTE(review): vague matches ("几个", "男性", "员工", ...) set
                # num_people to 0 rather than None; 0 here does not mean "no
                # people". Kept for compatibility with existing consumers.
                if token.isdigit():
                    info["num_people"] = int(token)
                else:
                    info["num_people"] = cls._CHINESE_COUNTS.get(token, 0)
                break

        for key, keywords in (
            ("actions", cls._ACTIONS),
            ("interactions", cls._INTERACTIONS),
            ("objects", cls._OBJECTS),
            ("furniture", cls._FURNITURE),
        ):
            info[key] = [word for word in keywords if word in answer]

        return info
|
|
|
|
# Shared analyzer instance used by the background Kafka worker.
media_analysis_system = MediaAnalysisSystem(model=model, processor=processor)
|
|
|
|
async def process_file(file: UploadFile, file_type: str):
    """Persist an uploaded file and enqueue it for asynchronous analysis.

    The file is saved in UPLOAD_DIR under a fresh ``qwenvl_<uuid><ext>`` name
    (keeping only the original extension), a "queued" status is recorded in
    Redis under ``<file_type>_result:<filename>``, and a task message is
    published to Kafka for the background worker.

    Args:
        file: The uploaded file.
        file_type: "image" or "video"; selects the Redis key prefix and the
            worker's processing path.

    Returns:
        dict with a human-readable ``message`` and the generated ``filename``
        that the client polls via ``/result/{filename}``.

    Raises:
        Whatever the file write, Redis, or Kafka producer raises; if the Kafka
        send fails, the Redis entry is first marked "failed".
    """
    content = await file.read()

    # UploadFile.filename may be None; fall back to no extension.
    original_extension = os.path.splitext(file.filename or "")[1]
    filename = f"qwenvl_{uuid.uuid4()}{original_extension}"
    file_path = os.path.join(UPLOAD_DIR, filename)
    with open(file_path, "wb") as f:
        f.write(content)

    # Record "queued" BEFORE publishing: otherwise a fast worker could write
    # "processing"/"completed" first and this write would overwrite it,
    # leaving the result stuck at "queued".
    redis_key = f"{file_type}_result:{filename}"
    redis_client.set(redis_key, json.dumps({"status": "queued"}))

    try:
        producer.send(KAFKA_TOPIC, json.dumps({
            "filename": filename,
            "type": file_type,
        }).encode('utf-8'))
    except Exception as exc:
        # Don't leave a task that was never enqueued reported as "queued".
        redis_client.set(redis_key, json.dumps({"status": "failed", "error": str(exc)}))
        raise

    return {"message": f"{file_type.capitalize()} uploaded and queued for processing", "filename": filename}
|
|
|
|
@qwenvl_app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """Accept an image or video upload and queue it for analysis.

    The media kind comes from the part's Content-Type: anything starting
    with "image" is treated as an image, everything else (including a missing
    Content-Type) as a video.

    Returns:
        The queueing response from ``process_file``, or a 500 JSON response
        ``{"error": "<message>"}`` if anything fails.
    """
    try:
        # content_type is None when the multipart part has no Content-Type.
        content_type = file.content_type or ""
        file_type = "image" if content_type.startswith("image") else "video"
        return await process_file(file, file_type)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
@qwenvl_app.post("/analyze_video")
async def analyze_video(file: UploadFile = File(...)):
    """Queue an uploaded file for video analysis.

    Returns the queueing response from ``process_file`` or, on any error, a
    500 JSON response of the form ``{"error": "<message>"}``.
    """
    try:
        queued = await process_file(file, "video")
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return queued
|
|
|
|
@qwenvl_app.post("/analyze_image")
async def analyze_image(file: UploadFile = File(...)):
    """Queue an uploaded file for image analysis.

    Returns the queueing response from ``process_file`` or, on any error, a
    500 JSON response of the form ``{"error": "<message>"}``.
    """
    try:
        queued = await process_file(file, "image")
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return queued
|
|
|
|
def process_task():
    """Consume analysis tasks from Kafka and store their results.

    Each message value is expected to be a dict with "filename" and "type"
    ("video" or "image"). The uploaded file is read from UPLOAD_DIR, analysed
    with ``media_analysis_system``, written to RESULT_DIR/<filename>.json, and
    its Redis entry ``<type>_result:<filename>`` moves from "processing" to
    "completed" (with the result) or "failed" (with the error text).

    Loops for as long as the consumer yields messages; it is started in a
    daemon thread from ``__main__``. Per-task errors are printed and recorded,
    not raised.
    """
    for message in consumer:
        # Reset per message: the error handler uses these to decide which
        # Redis entry to mark failed, so they must never carry over from a
        # previous iteration (which used to mark the wrong task as failed).
        filename = None
        file_type = None
        try:
            task = message.value
            if not isinstance(task, dict):
                # Fallback for a payload that was not already deserialized.
                if isinstance(task, (bytes, bytearray)):
                    task = task.decode('utf-8')
                task = json.loads(task)

            filename = task['filename']
            file_type = task['type']

            # Task messages arrive over the network: accept only a bare file
            # name so nothing outside UPLOAD_DIR/RESULT_DIR can be read or
            # written.
            if (not filename or filename in (".", "..")
                    or os.path.basename(filename) != filename):
                raise ValueError(f"Invalid filename in task: {filename!r}")
            # Reject unknown types explicitly; previously `result` was left
            # unbound or reused from the previous message.
            if file_type not in ("video", "image"):
                raise ValueError(f"Unsupported file type: {file_type!r}")

            file_path = os.path.join(UPLOAD_DIR, filename)
            redis_key = f"{file_type}_result:{filename}"
            redis_client.set(redis_key, json.dumps({"status": "processing"}))

            with open(file_path, 'rb') as f:
                file_data = f.read()

            if file_type == "video":
                result = media_analysis_system.process_video(file_data, filename)
            else:
                result = media_analysis_system.process_image(file_data, filename)

            # Persist the result as a JSON file next to other results.
            result_file_path = os.path.join(RESULT_DIR, f"{filename}.json")
            with open(result_file_path, 'w', encoding='utf-8') as f:
                json.dump(result, f, ensure_ascii=False, indent=2)

            # Publish the final result for /result/{filename}.
            redis_client.set(redis_key, json.dumps({
                "status": "completed",
                "result": result,
            }))

        except Exception as e:
            print(f"Error processing task: {str(e)}")
            if filename is not None and file_type is not None:
                try:
                    redis_client.set(
                        f"{file_type}_result:{filename}",
                        json.dumps({"status": "failed", "error": str(e)}),
                    )
                except Exception as redis_error:
                    # Keep the worker alive even if recording the failure fails.
                    print(f"Failed to record task failure: {redis_error}")
            else:
                print("Error occurred before task details were extracted")
|
|
|
|
@qwenvl_app.get("/result/{filename}")
async def get_result(filename: str):
    """Report the analysis status/result for an uploaded file.

    Looks up ``video_result:<filename>`` first, then ``image_result:<filename>``.
    For the pending states "queued" and "processing", a status plus a
    human-readable message is returned; any other stored payload (e.g. a
    completed or failed record) is returned as stored.

    Raises:
        HTTPException: 404 if neither key exists.
    """
    pending_states = (
        ("queued", "Your request is in the queue and will be processed soon."),
        ("processing", "Your request is being processed."),
    )

    for kind in ("video", "image"):
        raw = redis_client.get(f"{kind}_result:{filename}")
        if not raw:
            continue

        payload = json.loads(raw)
        status = payload.get("status")
        for pending, message in pending_states:
            if status == pending:
                return {"status": pending, "message": message}
        return payload

    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
async def listen_redis_changes():
    """Print changes to ``*_result:*`` keys in this app's Redis database.

    Pattern-subscribes to keyspace notifications for the DB the client uses
    and prints the last ':'-separated segment of each notified channel (the
    filename part of the key).

    NOTE(review): keyspace events are only delivered if the Redis server has
    ``notify-keyspace-events`` enabled — confirm server config.
    NOTE(review): although declared ``async``, ``pubsub.listen()`` is a
    blocking iterator from the synchronous client, so awaiting this coroutine
    blocks the event loop; it is not called from the visible code.
    """
    pubsub = redis_client.pubsub()
    # Keyspace channels are per database: subscribe to REDIS_DB, the DB this
    # client writes results to, instead of the previously hard-coded DB 5.
    pubsub.psubscribe(f'__keyspace@{REDIS_DB}__:*_result:*')

    for message in pubsub.listen():
        if message['type'] == 'pmessage':
            key = message['channel'].decode('utf-8').split(':')[-1]
            print(f"Key changed: {key}")
|
|
|
|
if __name__ == "__main__":
    # Run the Kafka worker in a daemon thread so it does not block the web
    # server and terminates together with the main process.
    worker_thread = threading.Thread(target=process_task, daemon=True)
    worker_thread.start()

    import uvicorn

    # Serve the root app (with the /qwenvl sub-app) on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7005)