353 lines
14 KiB
Python
353 lines
14 KiB
Python
import os
|
||
# 在导入其他库之前设置使用 CUDA 1
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||
|
||
import json
|
||
import time
|
||
from datetime import datetime
|
||
import redis
|
||
from deepface import DeepFace
|
||
import numpy as np
|
||
import gc
|
||
import re
|
||
from config import REDIS_CLIENTS, REDIS_IDENTITY, PATH_CONFIG
|
||
|
||
class FaceAnalysisSystem:
    """Face embedding extraction plus identity lookup against a Redis database.

    ``redis_clients`` (from ``config.REDIS_CLIENTS``) is exposed for callers
    that store analysis results; ``identity_db`` (from ``config.REDIS_IDENTITY``)
    holds the reference embeddings used by :meth:`find_identity`.
    """

    # Default cosine-similarity threshold a best match must exceed to be accepted.
    DEFAULT_MATCH_THRESHOLD = 0.72

    def __init__(self):
        # Redis clients used to store results (see config.REDIS_CLIENTS).
        self.redis_clients = REDIS_CLIENTS
        # Identity database: key -> JSON-encoded list of {"embedding": [...]} records.
        self.identity_db = REDIS_IDENTITY

    def get_face_embedding(self, img_path):
        """Return the Facenet512 embedding of the first face DeepFace reports.

        Detection uses the RetinaFace backend with alignment enabled.

        Args:
            img_path: Path of the image to analyse.

        Returns:
            The embedding of the first entry returned by ``DeepFace.represent``,
            or ``None`` when DeepFace returns nothing or raises (the error
            message is printed).
        """
        try:
            embedding_obj = DeepFace.represent(
                img_path=img_path,
                detector_backend="retinaface",
                align=True,
                model_name="Facenet512",
            )
            # One entry per detected face; only the first one is used.
            return embedding_obj[0]["embedding"] if embedding_obj else None
        except Exception as e:
            print(f"获取人脸embedding失败: {str(e)}")
            return None

    @staticmethod
    def _key_to_str(key):
        """Return a Redis key as ``str`` whether the client yields bytes or str."""
        return key.decode() if isinstance(key, bytes) else key

    def find_identity(self, embedding, threshold=DEFAULT_MATCH_THRESHOLD):
        """Find the stored identity whose embedding is most similar to *embedding*.

        Every key of ``identity_db`` is read; values that decode to a JSON list
        are treated as several reference embeddings for that identity and
        compared by cosine similarity. Entries without an ``"embedding"`` field
        and keys that vanish between KEYS and GET are skipped instead of
        aborting the whole lookup.

        Args:
            embedding: Query embedding (sequence of numbers).
            threshold: The best similarity must be strictly greater than this
                for a match to be reported. Defaults to 0.72.

        Returns:
            ``(identity, similarity)`` for an accepted match, otherwise
            ``("unknown", best_similarity)``; ``("unknown", -1)`` if the
            lookup raises.
        """
        try:
            query = np.asarray(embedding, dtype=float)
            # Loop-invariant: the query norm is the same for every comparison.
            query_norm = np.linalg.norm(query)

            best_match = None
            best_similarity = -1

            for identity_key in self.identity_db.keys("*"):
                raw = self.identity_db.get(identity_key)
                if raw is None:
                    # Key removed between KEYS and GET; nothing to compare against.
                    continue
                stored_data = json.loads(raw)

                # Only the list format (several embeddings per identity) is matched.
                if not isinstance(stored_data, list):
                    continue

                for face_data in stored_data:
                    stored = face_data.get("embedding") if isinstance(face_data, dict) else None
                    if stored is None:
                        # Malformed record: skip it rather than fail the whole lookup.
                        continue
                    stored_vector = np.asarray(stored, dtype=float)

                    # Cosine similarity.
                    similarity = np.dot(query, stored_vector) / (
                        query_norm * np.linalg.norm(stored_vector)
                    )

                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_match = self._key_to_str(identity_key)

            if best_similarity > threshold:
                return best_match, best_similarity
            return "unknown", best_similarity

        except Exception as e:
            print(f"查找身份时出错: {str(e)}")
            return "unknown", -1
|
||
|
||
class ImageMonitor:
    """Poll per-camera crop directories and store face-identification results.

    Layout handled: ``<images_path>/<camera_dir>/<camera>_<YYYYMMDD>_<HHMMSS>[_<n>].<ext>``,
    where the optional ``_<n>`` suffix marks several crops cut from the same
    original frame. All crops of one frame are analysed together and their
    results are stored, per hour, in the Redis client registered for
    ``<camera_dir>`` under the key ``face_<camera>_<YYYYMMDD>_<HH>00``.
    """

    # Extensions treated as images (compared case-insensitively).
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # A crop suffix "_<n>" is only recognised directly after the
    # "_YYYYMMDD_HHMMSS" stamp, so the time digits of an un-suffixed file are
    # never mistaken for a crop suffix.
    _CROP_SUFFIX_BEFORE_EXT = re.compile(r'(_\d{8}_\d{6})_\d+(?=\.(?:jpe?g|png))', re.IGNORECASE)
    _CROP_SUFFIX_AT_END = re.compile(r'(_\d{8}_\d{6})_\d+$')
    # <camera>_<date>_<time>.<ext> after the crop suffix has been removed.
    _NAME_PATTERN = re.compile(r'(\w+)_(\d{8})_(\d{6})\.(jpe?g|png)', re.IGNORECASE)
    _TIMESTAMP_PATTERN = re.compile(r'(\d{4})(\d{2})(\d{2})_(\d{2})(\d{2})(\d{2})')
    # Files smaller than this are rejected as truncated/invalid.
    _MIN_FILE_SIZE = 10 * 1024

    def __init__(self, images_path):
        self.images_path = images_path
        self.system = FaceAnalysisSystem()
        # Paths already analysed (restored from Redis at start-up).
        self.processed_images = self._load_processed_images()
        # Pending error records, written to disk by _save_error_log().
        self.error_images = []
        # Paths that failed once; they are not retried while the process runs.
        self.error_image_cache = set()

    def _load_processed_images(self):
        """Rebuild the set of already-processed image paths from Redis.

        For every camera client, each frame name stored in a ``face_*`` hourly
        record marks all files in ``<images_path>/<camera_id>`` starting with
        that frame's stem (the original and its crops) as processed. A camera
        whose directory cannot be listed is skipped instead of discarding the
        records of every other camera.
        """
        processed = set()
        try:
            for camera_id, redis_client in self.system.redis_clients.items():
                camera_dir = os.path.join(self.images_path, camera_id)
                try:
                    # List the directory once per camera, not once per stored frame.
                    dir_entries = os.listdir(camera_dir)
                except OSError as e:
                    print(f"无法读取目录 {camera_dir}: {str(e)}")
                    continue

                for key in redis_client.keys("face_*"):
                    data = redis_client.get(key)
                    if not data:
                        continue
                    hour_results = json.loads(data)
                    # Keys of an hourly record are frame file names.
                    for base_name in hour_results:
                        prefix = os.path.splitext(base_name)[0]
                        processed.update(
                            os.path.join(camera_dir, entry)
                            for entry in dir_entries
                            if entry.startswith(prefix)
                        )

            print(f"从Redis加载了 {len(processed)} 个已处理的图片记录")
            return processed

        except Exception as e:
            print(f"加载已处理图片记录时出错: {str(e)}")
            return set()

    def _get_redis_key(self, image_path):
        """Return ``(redis_key, base_name)`` for *image_path*, or ``(None, None)``.

        ``base_name`` is the file name with any crop suffix removed
        (``A01_20250105_130512_2.jpg`` -> ``A01_20250105_130512.jpg``) and the
        key groups frames by hour (``face_A01_20250105_1300``).
        """
        try:
            file_name = os.path.basename(image_path)
            base_name = self._CROP_SUFFIX_BEFORE_EXT.sub(r'\1', file_name)

            match = self._NAME_PATTERN.search(base_name)
            if match:
                camera_id, date, time_str = match.group(1), match.group(2), match.group(3)
                hour = time_str[:2]
                redis_key = f"face_{camera_id}_{date}_{hour}00"
                return redis_key, base_name

            print(f"文件名格式不匹配: {file_name}")
            return None, None

        except Exception as e:
            print(f"生成Redis key失败: {str(e)}")
            return None, None

    def _get_base_images(self, image_path):
        """Return every image in *image_path*'s directory cut from the same frame.

        Falls back to ``[image_path]`` if the directory cannot be listed.
        """
        try:
            dir_path = os.path.dirname(image_path)
            stem = os.path.splitext(os.path.basename(image_path))[0]
            base_name = self._CROP_SUFFIX_AT_END.sub(r'\1', stem)

            return [
                os.path.join(dir_path, file_name)
                for file_name in os.listdir(dir_path)
                if file_name.startswith(base_name)
                and file_name.lower().endswith(self._IMAGE_EXTENSIONS)
            ]
        except Exception as e:
            print(f"获取相关图片失败: {str(e)}")
            return [image_path]

    def process_new_image(self, image_path):
        """Analyse *image_path* together with its sibling crops and save results.

        Returns:
            True when the frame was handled (including "already processed" and
            "no usable face found"); False when the path is error-cached, no
            related image exists, no Redis key can be derived, or an
            exception occurs.
        """
        try:
            if self._is_error_cached(image_path):
                return False

            related_images = self._get_base_images(image_path)
            if not related_images:
                return False

            if all(img in self.processed_images for img in related_images):
                return True

            redis_key, base_name = self._get_redis_key(image_path)
            if not redis_key or not base_name:
                self._log_error(image_path, "Redis Key Error", "无法生成Redis key")
                return False

            identity_results, timestamp = self._analyze_images(related_images)
            if identity_results:
                self._save_results(image_path, redis_key, base_name, identity_results, timestamp)

            return True

        except Exception as e:
            self._log_error(image_path, "Processing Error", str(e))
            print(f"处理图片时发生错误 {image_path}: {str(e)}")
            return False

        finally:
            gc.collect()

    def _analyze_images(self, related_images):
        """Match faces in the not-yet-handled images; keep the best hit per identity.

        Returns ``(identity_results, timestamp)`` where ``identity_results``
        maps identity -> ``{"similarity", "file_name"}`` and ``timestamp`` is
        taken from the first analysed file name (None if nothing was analysed).
        """
        identity_results = {}
        timestamp = None

        for img in related_images:
            # Skip images already analysed or already known to fail.
            if img in self.processed_images or self._is_error_cached(img):
                continue

            if not os.path.exists(img):
                self._log_error(img, "File Not Found", "图片文件不存在")
                continue

            file_size = os.path.getsize(img)
            if file_size < self._MIN_FILE_SIZE:
                self._log_error(img, "Invalid File Size", f"图片文件大小异常({file_size/1024:.2f}KB)")
                continue

            embedding = self.system.get_face_embedding(img)
            if embedding is None:
                self._log_error(img, "Face Detection Error", "无法检测到人脸或提取特征")
                continue

            identity, similarity = self.system.find_identity(embedding)

            if not timestamp:
                timestamp = self._timestamp_from_name(os.path.basename(img))

            best = identity_results.get(identity)
            if best is None or similarity > best["similarity"]:
                identity_results[identity] = {
                    "similarity": float(similarity),
                    "file_name": os.path.basename(img),
                }

            self.processed_images.add(img)

        return identity_results, timestamp

    def _timestamp_from_name(self, file_name):
        """Format the YYYYMMDD_HHMMSS stamp in *file_name*; fall back to now."""
        match = self._TIMESTAMP_PATTERN.search(file_name)
        if match:
            year, month, day, hour, minute, second = match.groups()
            return f"{year}-{month}-{day} {hour}:{minute}:{second}"
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def _save_results(self, image_path, redis_key, base_name, identity_results, timestamp):
        """Merge *identity_results* into the hourly Redis record for *base_name*.

        Results stored earlier for the same frame (crops handled in a previous
        polling pass) are kept; each identity retains its highest similarity,
        and an existing timestamp is preserved.
        """
        dir_name = os.path.basename(os.path.dirname(image_path))
        if dir_name not in self.system.redis_clients:
            print(f"未找到目录 {dir_name} 对应的Redis客户端,结果未保存: {image_path}")
            return
        redis_client = self.system.redis_clients[dir_name]

        existing_data = redis_client.get(redis_key)
        hour_results = json.loads(existing_data) if existing_data else {}

        previous = hour_results.get(base_name)
        if not isinstance(previous, dict):
            previous = {}
        merged = dict(previous.get("face_analysis") or {})
        for identity, result in identity_results.items():
            old = merged.get(identity)
            if not isinstance(old, dict) or result["similarity"] > old.get("similarity", -1):
                merged[identity] = result

        hour_results[base_name] = {
            "face_analysis": merged,
            "timestamp": previous.get("timestamp") or timestamp,
        }

        # NOTE(review): this GET/SET pair is not atomic; concurrent writers to
        # the same hourly key could lose updates.
        redis_client.set(redis_key, json.dumps(hour_results, ensure_ascii=False))
        print(f"成功保存到Redis,key: {redis_key}")

    def _is_processed(self, image_path):
        """Return True if *image_path* has already been analysed."""
        return image_path in self.processed_images

    def _is_error_cached(self, image_path):
        """Return True if *image_path* previously failed."""
        return image_path in self.error_image_cache

    def _add_to_error_cache(self, image_path):
        """Mark *image_path* as failed so it is not retried."""
        self.error_image_cache.add(image_path)

    def _log_error(self, image_path, error_type, error_message):
        """Record a failure for *image_path* once and add it to the error cache."""
        if self._is_error_cached(image_path):
            return

        try:
            file_size = os.path.getsize(image_path)
        except OSError:
            # Missing/unreadable file (also covers deletion after a check).
            file_size = 0

        self.error_images.append({
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "image_path": image_path,
            "error_type": error_type,
            "error_message": error_message,
            "file_size": file_size,
        })
        self._add_to_error_cache(image_path)

    def _save_error_log(self):
        """Write pending error records to ``image_errors_<timestamp>.json`` and clear them."""
        if not self.error_images:
            return

        try:
            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_filename = f"image_errors_{current_time}.json"

            with open(log_filename, 'w', encoding='utf-8') as f:
                json.dump(self.error_images, f, ensure_ascii=False, indent=2)
            print(f"\n异常图片记录已保存到: {log_filename}")

            self.error_images = []
        except Exception as e:
            print(f"保存错误日志失败: {str(e)}")

    def monitor_directories(self):
        """Scan every camera directory once a minute until interrupted.

        On KeyboardInterrupt, or on an exception escaping the loop, pending
        error records are saved first; the latter exception is re-raised.
        """
        try:
            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print(f"开始监控目录: {self.images_path} [{current_time}]")

            while True:
                try:
                    for camera_dir in os.listdir(self.images_path):
                        camera_path = os.path.join(self.images_path, camera_dir)
                        if not os.path.isdir(camera_path):
                            continue

                        for image_file in os.listdir(camera_path):
                            if not image_file.lower().endswith(self._IMAGE_EXTENSIONS):
                                continue

                            image_path = os.path.join(camera_path, image_file)
                            # NOTE(review): a file still being written may fail the
                            # size check and then stay error-cached permanently.
                            if not self._is_processed(image_path) and not self._is_error_cached(image_path):
                                print(f"处理图片: {image_path}")
                                if not self.process_new_image(image_path):
                                    self._add_to_error_cache(image_path)
                                    print(f"图片处理失败,已加入错误缓存: {image_path}")

                    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    print(f"[{current_time}] 等待新图片中...")
                    time.sleep(60)  # poll once per minute

                except Exception as e:
                    print(f"监控过程出错: {str(e)}")
                    time.sleep(10)

        except KeyboardInterrupt:
            print("\n检测到程序终止信号,正在保存错误日志...")
            self._save_error_log()
            print("程序已安全终止。")
        except Exception as e:
            print(f"\n程序异常终止: {str(e)}")
            self._save_error_log()
            raise
|
||
|
||
def main(images_path="files/crop"):
    """Monitor *images_path* (the crop directory) until interrupted.

    Args:
        images_path: Root directory containing one sub-directory per camera.
            Defaults to ``"files/crop"``.

    Unexpected exceptions are reported with their full traceback rather than
    only the exception message, so failures remain diagnosable.
    """
    import traceback

    try:
        monitor = ImageMonitor(images_path)
        monitor.monitor_directories()

    except Exception as e:
        print(f"\n未预期的错误: {str(e)}")
        traceback.print_exc()
|
||
|
||
# Run the monitor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()