Files
zydi-web/test_history/qwen.py
T
2025-01-12 03:01:51 +00:00

325 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import json
import torch
from datetime import datetime
from PIL import Image
import io
import re
from decord import VideoReader
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# Configuration
QWEN_MODEL_PATH = "/obscura/models/qwen/Qwen2-VL-7B-Instruct"

# Pixel bounds forwarded to the processor below.
min_pixels = 128 * 28 * 28
max_pixels = 256 * 28 * 28

# Initialise the Qwen model (placed on cuda:0).
print("正在初始化 Qwen 模型 (cuda:0)...")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    QWEN_MODEL_PATH, torch_dtype="auto", device_map="cuda:0"
)

# The processor is loaded from the same checkpoint directory as the model.
processor = AutoProcessor.from_pretrained(
    QWEN_MODEL_PATH, min_pixels=min_pixels, max_pixels=max_pixels
)
# Configuration loading
def load_config(config_path='info.json'):
    """Load the keyword configuration from a JSON file.

    The result is always a dict that contains at least the keys
    ``"actions"`` and ``"environments"``; keys missing from the file are
    filled in with empty lists so that later ``CONFIG[...]`` lookups cannot
    raise ``KeyError``.

    Args:
        config_path: Path of the JSON file to read. Defaults to
            ``'info.json'`` relative to the current working directory.

    Returns:
        The parsed configuration merged over the defaults. If the file cannot
        be read, is not valid JSON, or its top level is not a JSON object, a
        message is printed and the defaults alone are returned.
    """
    defaults = {"actions": [], "environments": []}
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if not isinstance(config, dict):
            # Routed through the same handler so the failure is reported once.
            raise ValueError("配置文件顶层必须是 JSON 对象")
    except (OSError, ValueError) as e:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"加载配置文件失败: {e}")
        return defaults
    return {**defaults, **config}
# Load the configuration once, when the module is imported
CONFIG = load_config()
class MediaAnalysisSystem:
    """Run the Qwen2-VL model on images/videos and mine the answer for keywords.

    The instance reuses the module-level ``model`` and ``processor`` objects
    and takes its environment/action keyword lists from the module-level
    ``CONFIG`` dict.
    """

    # Numeral tokens captured by the people-count patterns, mapped to ints.
    _NUMERAL_VALUES = {
        '一个': 1, '一': 1,
        '二': 2, '三': 3, '四': 4, '五': 5,
        '六': 6, '七': 7, '8': 8, '八': 8, '九': 9, '十': 10,
    }

    # Tried in order; the first pattern that matches decides ``num_people``.
    # Only the first three capture an actual count in group 1; the remaining
    # ones merely signal that people are mentioned.
    _PEOPLE_PATTERNS = (
        r'(\d+)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一|二|三|四|五|六|七|八|九|十)\s*(人|个人|位|名|员工|用户|小朋友|成年人|女性|男性)',
        r'(一个|几个)\s*(人|个人|员工|用户|小朋友|成年人|女性|男性)',
        r'\s*(名|位)\s*(人|员工|用户|小朋友|成年人|女性|男性)?',
        r'(男|女)(性|生|士)',
        r'(成年|未成年|青少年|老年)\s*(人|群体)',
        r'(员工|职工|工人|学生|顾客|观众|游客|乘客)',
        r'(群众|民众|大众|公众)',
        r'(男女|老少|老幼|大人|小孩)',
    )

    # NOTE(review): the keyword lists below previously contained several
    # empty-string entries (their characters appear to have been lost). They
    # are omitted because "" is a substring of every answer and was therefore
    # always reported as "found". Restore the intended words if they are known.
    _EMOTIONS = (
        "钦佩", "赞赏", "欣赏", "关心", "高兴", "乐观", "感激", "释然", "骄傲", "愉悦",
        "愤怒", "烦恼", "焦虑", "尴尬", "失望", "厌恶", "恐惧", "悲伤", "懊悔", "羞耻", "发呆",
        "困惑", "好奇", "欲望", "惊讶", "实事求是", "中性", "赞叹", "平静", "放松", "专注", "思考",
    )
    _OBJECTS = (
        "水瓶", "办公用品", "文件", "电脑", "风扇", "鼠标", "键盘", "纸巾",
        "袋子", "盒子", "水杯", "杯子", "马克杯", "玻璃杯", "文件夹", "书包",
        "书架", "文件柜", "手机",
    )
    _FURNITURE = (
        "椅子", "桌子", "咖啡桌", "文件柜", "沙发", "柜子", "架子", "摄像头",
        "靠垫", "办公椅", "电视", "白板", "显示器", "置物架", "文件架",
    )
    _FEATURES = (
        "戴眼镜", "不戴眼镜", "长发", "短发", "长头发", "短头发", "戴帽子", "不戴帽子",
        "戴口罩", "不戴口罩", "男性", "女性", "成年人",
    )

    def __init__(self):
        """Bind the shared model/processor and the configured keyword lists."""
        self.MAX_NUM_FRAMES = 10
        self.device = "cuda:0"
        self.qwen_model = model
        self.qwen_processor = processor
        # Keyword lists come from the loaded configuration.
        self.environments = CONFIG["environments"]
        self.actions = CONFIG["actions"]

    @staticmethod
    def _uniform_sample(items, n):
        """Pick ``n`` evenly spaced elements of ``items`` (centre of each bucket)."""
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    def encode_video(self, video_data):
        """Decode raw video bytes into a list of sampled PIL images.

        Frames are taken at a stride of ``round(average fps)`` frames and, if
        that still yields more than ``MAX_NUM_FRAMES``, thinned out evenly.

        Args:
            video_data: The encoded video file contents as bytes.

        Returns:
            A list of ``PIL.Image`` frames.
        """
        video_file = io.BytesIO(video_data)
        vr = VideoReader(video_file)
        # A reported fps below 0.5 would round to 0 and make range() raise
        # ValueError, so the stride is clamped to at least one frame.
        sample_fps = max(1, round(vr.get_avg_fps()))
        frame_idx = list(range(0, len(vr), sample_fps))
        if len(frame_idx) > self.MAX_NUM_FRAMES:
            frame_idx = self._uniform_sample(frame_idx, self.MAX_NUM_FRAMES)
        frames = vr.get_batch(frame_idx).asnumpy()
        frames = [Image.fromarray(v.astype('uint8')) for v in frames]
        print('num frames:', len(frames))
        return frames

    def process_with_qwen(self, media_data, object_name, media_type='image'):
        """Ask the Qwen model to describe an image or video.

        Args:
            media_data: Encoded image or video bytes.
            object_name: Not used by this method; kept for caller compatibility.
            media_type: ``'video'`` to sample frames via ``encode_video``;
                any other value treats ``media_data`` as a single image.

        Returns:
            A dict with ``"model"`` (always ``"qwen"``), ``"original_answer"``
            (the decoded model output) and ``"extracted_info"`` (the result of
            ``extract_info`` on that answer).
        """
        if media_type == 'video':
            frames = self.encode_video(media_data)
            media_content = {"type": "video", "video": frames, "fps": 1.0}
        else:
            image = Image.open(io.BytesIO(media_data))
            media_content = {"type": "image", "image": image}

        messages = [
            {
                "role": "user",
                "content": [
                    media_content,
                    {"type": "text", "text": self._get_analysis_prompt(media_type)},
                ],
            }
        ]
        text = self.qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.qwen_processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        generated_ids = self.qwen_model.generate(**inputs, max_new_tokens=2048)
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        answer = self.qwen_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {
            "model": "qwen",
            "original_answer": answer,
            "extracted_info": self.extract_info(answer),
        }

    def _get_analysis_prompt(self, media_type):
        """Return the analysis prompt.

        ``media_type`` is accepted but not used: the same prompt text is
        returned for every media type.
        """
        return """你是一位视频描述专家,你擅长对视频进行详细的描述,请对这段监控视频进行详细分析,包括以下方面,并按照下面格式回答:
1. 环境场景
- 整体场景描述(室内/室外、光线条件等)
- 主要物品和家具列表
- 环境特征(如光线、整洁度等)
2. 人员统计
- 总人数:[数字]人
- 性别分布:[男性数量]/[女性数量]
(若无法确定准确人数,请注明"无法确定人数"
3. 人员特征分析
- 个人特征:性别、年龄段、着装、体态等
- 携带物品:详细描述随身物品及用途
- 表情/情绪状态
4. 行为分析
- 个人行为:移动方向、姿态、动作等
- 互动情况:人员之间的交互描述(若多人)
- 活动区域:人员活动的主要位置
5. 群体行为(若多人)
- 聚集形态
- 移动趋势
- 群体互动特点
6. 异常情况
- 可疑行为描述
- 异常活动标记
- 需要注意的安全隐患
请用清晰、有条理的格式描述,并突出重要发现。"""

    @staticmethod
    def _find_keywords(text, keywords):
        """Return the non-empty ``keywords`` that occur in ``text``, in list order."""
        return [word for word in keywords if word and word in text]

    def _parse_people_count(self, token):
        """Turn the first captured group of a people pattern into an int.

        Digit strings are converted directly and known Chinese numerals are
        looked up. Any other token (e.g. '几个', or the capture of a pattern
        that only signals that people are present) yields 0.
        """
        if token.isdigit():
            return int(token)
        return self._NUMERAL_VALUES.get(token, 0)

    def extract_info(self, answer):
        """Extract structured keyword information from a Chinese answer text.

        Args:
            answer: The model's textual answer.

        Returns:
            A dict with the keys ``environment`` (first configured environment
            found, else ``None``), ``num_people`` (int, or ``None`` when no
            people pattern matches), and the lists ``actions``, ``objects``,
            ``furniture``, ``emotions`` and ``features`` holding every keyword
            that occurs in the answer.
        """
        info = {
            "environment": None,
            "num_people": None,
            "actions": [],
            "objects": [],
            "furniture": [],
            "emotions": [],
            "features": []
        }

        # Environment: case-insensitive on both sides, first hit wins.
        lowered_answer = answer.lower()
        for env in self.environments:
            if env and env.lower() in lowered_answer:
                info["environment"] = env
                break

        # People count: the first pattern that matches decides the value.
        for pattern in self._PEOPLE_PATTERNS:
            match = re.search(pattern, answer)
            if match:
                info["num_people"] = self._parse_people_count(match.group(1))
                break

        # Plain substring keyword matching for everything else.
        info["actions"] = self._find_keywords(answer, self.actions)
        info["objects"] = self._find_keywords(answer, self._OBJECTS)
        info["furniture"] = self._find_keywords(answer, self._FURNITURE)
        info["features"] = self._find_keywords(answer, self._FEATURES)
        info["emotions"] = self._find_keywords(answer, self._EMOTIONS)
        return info
def _write_results_json(output_file, results):
    """Overwrite ``output_file`` with ``results`` serialised as UTF-8 JSON."""
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


def process_video_folder(system, folder_path, output_path=None):
    """Analyse every video file in a folder and save the results as JSON.

    Files directly inside ``folder_path`` with a ``.mp4``, ``.avi``, ``.mov``
    or ``.mkv`` extension (case-insensitive) are processed in sorted name
    order. The result file is rewritten after every video, so partial results
    survive an interruption.

    Args:
        system: Object providing
            ``process_with_qwen(video_bytes, name, media_type='video')``.
        folder_path: Directory containing the videos.
        output_path: Directory for the result file; created if missing.
            ``None`` means the current working directory.

    Returns:
        A dict keyed by file name. Each value is either
        ``{"video_analysis": {"qwen-7B": {...}}, "timestamp": ...}`` or
        ``{"error": message}`` when processing that file failed.

    Raises:
        MediaAnalysisError: If ``folder_path`` is not an existing directory or
            contains no supported video file.
    """
    import gc

    valid_extensions = {'.mp4', '.avi', '.mov', '.mkv'}
    results = {}

    # isdir (not exists) so that a path pointing at a regular file is rejected
    # here instead of failing later inside os.listdir.
    if not os.path.isdir(folder_path):
        raise MediaAnalysisError(f"错误:文件夹 '{folder_path}' 不存在")

    if output_path is None:
        output_path = os.getcwd()
    else:
        os.makedirs(output_path, exist_ok=True)

    # Sorted for a reproducible processing order; sub-directories whose names
    # happen to end in a video extension are skipped.
    video_files = sorted(
        f for f in os.listdir(folder_path)
        if os.path.splitext(f)[1].lower() in valid_extensions
        and os.path.isfile(os.path.join(folder_path, f))
    )
    if not video_files:
        raise MediaAnalysisError(f"错误:在文件夹 '{folder_path}' 中未找到支持的视频文件")

    print(f"\n找到 {len(video_files)} 个视频文件,开始处理...\n")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = os.path.basename(os.path.normpath(folder_path))
    output_file = os.path.join(output_path, f"analysis_results_{folder_name}_{timestamp}.json")

    for i, video_file in enumerate(video_files, 1):
        video_path = os.path.join(folder_path, video_file)
        print(f"正在处理 ({i}/{len(video_files)}): {video_file}")
        try:
            with open(video_path, "rb") as f:
                video_data = f.read()
            results[video_file] = {"video_analysis": {}}

            # Only the Qwen model is used for video analysis.
            print(f"使用 Qwen 处理视频: {video_file}")
            qwen_result = system.process_with_qwen(video_data, video_file, media_type='video')
            results[video_file]["video_analysis"]["qwen-7B"] = {
                "original_answer": qwen_result["original_answer"],
                "extracted_info": qwen_result["extracted_info"]
            }
            results[video_file]["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # Checkpoint after every video.
            _write_results_json(output_file, results)
            print(f"✓ 成功处理并保存: {video_file}")
        except Exception as e:
            # Best effort: record the failure and carry on with the next file.
            print(f"✗ 处理失败 {video_file}: {str(e)}")
            results[video_file] = {"error": str(e)}
            _write_results_json(output_file, results)
        finally:
            # Release memory after failures too (e.g. out-of-memory errors),
            # dropping Python references before emptying the CUDA cache.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"\n所有分析结果已保存到: {output_file}")
    return results
class MediaAnalysisError(Exception):
    """Custom exception type for media-analysis failures."""
def main():
    """Interactive entry point: prompt for folders, analyse, print a summary.

    Both ``MediaAnalysisError`` and any other ``Exception`` are reported on
    stdout instead of being propagated.
    """
    try:
        system = MediaAnalysisSystem()

        # Ask for the input folder and the (optional) output folder.
        folder_path = input("请输入视频文件夹路径: ").strip()
        # An empty answer becomes None, which selects the current directory.
        output_path = input("请输入结果保存路径 (直接回车使用当前目录): ").strip() or None

        # Process every video in the folder.
        results = process_video_folder(system, folder_path, output_path)

        # Summary: entries without an "error" key count as successes.
        succeeded = [name for name, entry in results.items() if "error" not in entry]
        print(f"\n处理完成!成功: {len(succeeded)}/{len(results)}")
    except MediaAnalysisError as e:
        print(f"\n错误: {e}")
    except Exception as e:
        print(f"\n未预期的错误: {e}")
# Run the interactive entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()