|
|
@@ -14,8 +14,10 @@ import threading
|
|
|
import queue
|
|
|
import time
|
|
|
import logging
|
|
|
+from datetime import datetime
|
|
|
from typing import Optional, List, Tuple, Dict, Any
|
|
|
from dataclasses import dataclass
|
|
|
+from pathlib import Path
|
|
|
|
|
|
from config import PANORAMA_CAMERA, DETECTION_CONFIG
|
|
|
from dahua_sdk import DahuaSDK, PTZCommand
|
|
|
@@ -420,6 +422,17 @@ class ObjectDetector:
|
|
|
self.config = DETECTION_CONFIG
|
|
|
self.device = 'cuda:0' if use_gpu else 'cpu'
|
|
|
|
|
|
+ # 检测图片保存配置
|
|
|
+ self._save_image_enabled = self.config.get('save_detection_image', False)
|
|
|
+ self._image_save_dir = Path(self.config.get('detection_image_dir', './detection_images'))
|
|
|
+ self._image_max_count = self.config.get('detection_image_max_count', 1000)
|
|
|
+ self._last_save_time = 0
|
|
|
+ self._save_interval = 1.0 # 最小保存间隔(秒),避免保存过于频繁
|
|
|
+
|
|
|
+ # 创建保存目录
|
|
|
+ if self._save_image_enabled:
|
|
|
+ self._ensure_save_dir()
|
|
|
+
|
|
|
# 根据扩展名自动判断模型类型
|
|
|
if model_path:
|
|
|
ext = os.path.splitext(model_path)[1].lower()
|
|
|
@@ -502,6 +515,96 @@ class ObjectDetector:
|
|
|
"""使用OpenCV加载模型"""
|
|
|
pass
|
|
|
|
|
|
+ def _ensure_save_dir(self):
|
|
|
+ """确保保存目录存在"""
|
|
|
+ try:
|
|
|
+ self._image_save_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ logger.info(f"检测图片保存目录: {self._image_save_dir}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"创建检测图片目录失败: {e}")
|
|
|
+ self._save_image_enabled = False
|
|
|
+
|
|
|
+ def _cleanup_old_images(self):
|
|
|
+ """清理旧图片,保持目录下图片数量不超过上限"""
|
|
|
+ try:
|
|
|
+ image_files = list(self._image_save_dir.glob("*.jpg"))
|
|
|
+ if len(image_files) > self._image_max_count:
|
|
|
+ # 按修改时间排序,删除最旧的
|
|
|
+ image_files.sort(key=lambda x: x.stat().st_mtime)
|
|
|
+ to_delete = image_files[:len(image_files) - self._image_max_count]
|
|
|
+ for f in to_delete:
|
|
|
+ f.unlink()
|
|
|
+ logger.info(f"已清理 {len(to_delete)} 张旧检测图片")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"清理旧图片失败: {e}")
|
|
|
+
|
|
|
+ def _save_detection_image(self, frame: np.ndarray, detections: List[DetectedObject]):
|
|
|
+ """
|
|
|
+ 保存带有检测标记的图片
|
|
|
+ Args:
|
|
|
+ frame: 原始图像
|
|
|
+ detections: 检测结果列表
|
|
|
+ """
|
|
|
+ if not self._save_image_enabled or not detections:
|
|
|
+ return
|
|
|
+
|
|
|
+ # 检查保存间隔
|
|
|
+ current_time = time.time()
|
|
|
+ if current_time - self._last_save_time < self._save_interval:
|
|
|
+ return
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 复制图像避免修改原图
|
|
|
+ marked_frame = frame.copy()
|
|
|
+
|
|
|
+ # 绘制检测结果
|
|
|
+ for det in detections:
|
|
|
+ x, y, w, h = det.bbox
|
|
|
+ # 绘制边界框(绿色)
|
|
|
+ cv2.rectangle(marked_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
|
|
|
+
|
|
|
+ # 绘制标签背景
|
|
|
+ label = f"{det.class_name}: {det.confidence:.2f}"
|
|
|
+ (label_w, label_h), baseline = cv2.getTextSize(
|
|
|
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2
|
|
|
+ )
|
|
|
+ cv2.rectangle(
|
|
|
+ marked_frame,
|
|
|
+ (x, y - label_h - 10),
|
|
|
+ (x + label_w, y),
|
|
|
+ (0, 255, 0),
|
|
|
+ -1
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制标签文字(黑色)
|
|
|
+ cv2.putText(
|
|
|
+ marked_frame, label,
|
|
|
+ (x, y - 5),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6,
|
|
|
+ (0, 0, 0), 2
|
|
|
+ )
|
|
|
+
|
|
|
+ # 绘制中心点(红色)
|
|
|
+ cv2.circle(marked_frame, det.center, 5, (0, 0, 255), -1)
|
|
|
+
|
|
|
+
|
|
|
+ # 生成文件名(时间戳)
|
|
|
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
|
|
+ filename = f"detection_{timestamp}.jpg"
|
|
|
+ filepath = self._image_save_dir / filename
|
|
|
+
|
|
|
+ # 保存图片
|
|
|
+ cv2.imwrite(str(filepath), marked_frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
|
|
|
+ self._last_save_time = current_time
|
|
|
+
|
|
|
+ logger.info(f"已保存检测图片: {filepath},检测到 {len(detections)} 个目标")
|
|
|
+
|
|
|
+ # 定期清理旧图片
|
|
|
+ self._cleanup_old_images()
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"保存检测图片失败: {e}")
|
|
|
+
|
|
|
def _letterbox(self, image, size=(640, 640)):
|
|
|
"""Letterbox 预处理"""
|
|
|
h0, w0 = image.shape[:2]
|
|
|
@@ -593,21 +696,24 @@ class ObjectDetector:
|
|
|
"""检测物体返回所有类别结果"""
|
|
|
if frame is None:
|
|
|
return []
|
|
|
-
|
|
|
+
|
|
|
if hasattr(self, 'rknn') and self.rknn is not None:
|
|
|
results = self._detect_rknn(frame)
|
|
|
if results:
|
|
|
self._log_detections("RKNN", results, frame)
|
|
|
+ self._save_detection_image(frame, results)
|
|
|
return results
|
|
|
elif hasattr(self, 'session') and self.session is not None:
|
|
|
results = self._detect_rknn(frame)
|
|
|
if results:
|
|
|
self._log_detections("ONNX", results, frame)
|
|
|
+ self._save_detection_image(frame, results)
|
|
|
return results
|
|
|
elif self.model is not None:
|
|
|
results = self._detect_yolo(frame)
|
|
|
if results:
|
|
|
self._log_detections("YOLO", results, frame)
|
|
|
+ self._save_detection_image(frame, results)
|
|
|
return results
|
|
|
else:
|
|
|
logger.error("[YOLO] 没有可用的检测模型")
|