safety_coordinator.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. """
  2. 安全联动控制器
  3. 整合安全检测、事件推送功能
  4. """
  5. import time
  6. import threading
  7. import queue
  8. from typing import Optional, List, Dict, Tuple, Callable
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. import numpy as np
  12. import cv2
  13. from config import (
  14. COORDINATOR_CONFIG,
  15. SAFETY_DETECTION_CONFIG,
  16. EVENT_PUSHER_CONFIG,
  17. EVENT_LISTENER_CONFIG,
  18. SYSTEM_CONFIG
  19. )
  20. from safety_detector import (
  21. SafetyDetector, SafetyDetection, PersonSafetyStatus,
  22. SafetyViolationType, draw_safety_result
  23. )
  24. from event_pusher import EventPusher, EventListener, SafetyEvent, EventType
  25. class CoordinatorState(Enum):
  26. """控制器状态"""
  27. IDLE = 0 # 空闲
  28. DETECTING = 1 # 检测中
  29. TRACKING = 2 # 跟踪中
  30. ALERTING = 3 # 告警中
  31. @dataclass
  32. class AlertRecord:
  33. """告警记录"""
  34. track_id: int # 跟踪ID
  35. violation_type: str # 违规类型
  36. description: str # 描述
  37. frame: Optional[np.ndarray] # 图像
  38. timestamp: float # 时间戳
  39. pushed: bool = False # 是否已推送
  40. class SafetyCoordinator:
  41. """安全联动控制器:协调摄像头、安全检测、事件推送、PTZ跟踪"""
  42. def __init__(self, camera, config: Dict = None, ptz_camera=None, calibrator=None):
  43. self.camera = camera
  44. self.config = config or {}
  45. self.ptz = ptz_camera # PTZ球机(可选)
  46. self.calibrator = calibrator # 校准器(可选)
  47. self.detector = None
  48. self.event_pusher = None
  49. self.event_listener = None
  50. self.state = CoordinatorState.IDLE
  51. self.state_lock = threading.Lock()
  52. self.running = False
  53. self.worker_thread = None
  54. self.alert_records: List[AlertRecord] = []
  55. self.alert_cooldown = {}
  56. # 告警冷却时间(按违规类型)
  57. self._violation_cooldown = {}
  58. self.stats = {
  59. 'frames_processed': 0,
  60. 'persons_detected': 0,
  61. 'violations_detected': 0,
  62. 'events_pushed': 0,
  63. 'ptz_commands_sent': 0,
  64. 'start_time': None
  65. }
  66. self.stats_lock = threading.Lock()
  67. self.on_violation_detected: Optional[Callable] = None
  68. self.on_frame_processed: Optional[Callable] = None
  69. self._init_components()
  70. def _init_components(self):
  71. """初始化各组件"""
  72. # 从 SYSTEM_CONFIG 读取功能开关
  73. enable_detection = SYSTEM_CONFIG.get('enable_detection', True)
  74. enable_safety_detection = SYSTEM_CONFIG.get('enable_safety_detection', True)
  75. enable_event_push = SYSTEM_CONFIG.get('enable_event_push', True)
  76. # 安全检测器
  77. if enable_detection and enable_safety_detection:
  78. try:
  79. self.detector = SafetyDetector(
  80. model_path=SAFETY_DETECTION_CONFIG.get('model_path'),
  81. use_gpu=SAFETY_DETECTION_CONFIG.get('use_gpu', True),
  82. conf_threshold=SAFETY_DETECTION_CONFIG.get('conf_threshold', 0.5),
  83. person_threshold=SAFETY_DETECTION_CONFIG.get('person_threshold', 0.8)
  84. )
  85. print("安全检测器初始化成功")
  86. except Exception as e:
  87. print(f"安全检测器初始化失败: {e}")
  88. else:
  89. print("安全检测功能已禁用")
  90. # 事件推送器
  91. if enable_event_push:
  92. try:
  93. self.event_pusher = EventPusher(EVENT_PUSHER_CONFIG)
  94. print("事件推送器初始化成功")
  95. except Exception as e:
  96. print(f"事件推送器初始化失败: {e}")
  97. else:
  98. print("事件推送功能已禁用")
  99. # 事件监听器
  100. if EVENT_LISTENER_CONFIG.get('enabled', True):
  101. try:
  102. self.event_listener = EventListener(EVENT_LISTENER_CONFIG)
  103. print("事件监听器初始化成功")
  104. except Exception as e:
  105. print(f"事件监听器初始化失败: {e}")
  106. def start(self) -> bool:
  107. """启动控制器"""
  108. if self.running:
  109. return True
  110. if self.event_pusher:
  111. self.event_pusher.start()
  112. if self.event_listener:
  113. self.event_listener.start()
  114. self.running = True
  115. self.worker_thread = threading.Thread(target=self._worker, daemon=True)
  116. self.worker_thread.start()
  117. # PTZ跟踪已禁用
  118. with self.stats_lock:
  119. self.stats['start_time'] = time.time()
  120. print("安全联动控制器已启动")
  121. return True
  122. def stop(self):
  123. """停止控制器"""
  124. self.running = False
  125. if self.worker_thread:
  126. self.worker_thread.join(timeout=3)
  127. # PTZ跟踪已禁用
  128. if self.event_pusher:
  129. self.event_pusher.stop()
  130. if self.event_listener:
  131. self.event_listener.stop()
  132. self._print_stats()
  133. print("安全联动控制器已停止")
  134. def _worker(self):
  135. """工作线程"""
  136. # 优先使用 detection_fps,默认每秒2帧
  137. detection_fps = SAFETY_DETECTION_CONFIG.get('detection_fps', 2)
  138. detection_interval = 1.0 / detection_fps # 根据FPS计算间隔
  139. last_detection_time = 0
  140. detection_run_count = 0
  141. detection_violation_count = 0
  142. frame_count = 0
  143. last_log_time = time.time()
  144. heartbeat_interval = 30.0
  145. last_no_detect_log_time = 0
  146. import logging
  147. sc_logger = logging.getLogger(__name__)
  148. if self.detector is None:
  149. sc_logger.warning("[安全检测] ⚠️ 安全检测器未初始化! 安全检测不可用")
  150. else:
  151. sc_logger.info(f"[安全检测] ✓ 安全检测器已就绪, 检测帧率={detection_fps}fps(间隔={detection_interval:.2f}s)")
  152. while self.running:
  153. try:
  154. current_time = time.time()
  155. frame = self.camera.get_frame() if self.camera else None
  156. if frame is None:
  157. time.sleep(0.01)
  158. continue
  159. frame_count += 1
  160. self._update_stats('frames_processed')
  161. if current_time - last_log_time >= heartbeat_interval:
  162. stats = self.get_stats()
  163. state_str = self.state.name if hasattr(self.state, 'name') else str(self.state)
  164. sc_logger.info(
  165. f"[安全检测] 状态={state_str}, "
  166. f"检测轮次={detection_run_count}(有人={detection_violation_count}), "
  167. f"帧数={frame_count}"
  168. )
  169. frame_count = 0
  170. last_log_time = current_time
  171. if current_time - last_detection_time >= detection_interval:
  172. last_detection_time = current_time
  173. detection_run_count += 1
  174. result = self._process_frame_with_logging(frame, detection_run_count, detection_violation_count, last_no_detect_log_time, sc_logger)
  175. detection_violation_count = result
  176. time.sleep(0.01)
  177. except Exception as e:
  178. sc_logger.error(f"[安全检测] 处理错误: {e}")
  179. time.sleep(0.1)
  180. def _process_frame_with_logging(self, frame: np.ndarray, run_count: int, violation_count: int, last_no_detect_time: float, sc_logger) -> int:
  181. """处理帧并返回更新的violation_count"""
  182. if self.detector is None:
  183. return violation_count
  184. self._set_state(CoordinatorState.DETECTING)
  185. detections = self.detector.detect(frame)
  186. status_list = self.detector.check_safety(frame, detections)
  187. self._update_stats('persons_detected', len(status_list))
  188. # 轨迹追踪已禁用
  189. has_violation = False
  190. for status in status_list:
  191. if status.is_violation:
  192. self._handle_violation(status, frame)
  193. has_violation = True
  194. if has_violation:
  195. violation_count += 1
  196. if not status_list:
  197. current_time = time.time()
  198. if current_time - last_no_detect_time >= 30.0:
  199. sc_logger.info(
  200. f"[安全检测] · YOLO检测运行正常, 本轮未检测到人员 "
  201. f"(累计检测{run_count}轮, 违规{violation_count}轮)"
  202. )
  203. if self.on_frame_processed:
  204. self.on_frame_processed(frame, detections, status_list)
  205. return violation_count
  206. def _process_frame(self, frame: np.ndarray):
  207. """处理帧"""
  208. if self.detector is None:
  209. return
  210. self._set_state(CoordinatorState.DETECTING)
  211. # 安全检测
  212. detections = self.detector.detect(frame)
  213. status_list = self.detector.check_safety(frame, detections)
  214. self._update_stats('persons_detected', len(status_list))
  215. # 检查违规(轨迹追踪已禁用)
  216. for status in status_list:
  217. if status.is_violation:
  218. self._handle_violation(status, frame)
  219. # 回调
  220. if self.on_frame_processed:
  221. self.on_frame_processed(frame, detections, status_list)
  222. # 轨迹追踪已禁用 - _update_tracks 和 _cleanup_tracks 方法已移除
  223. def _handle_violation(self, status: PersonSafetyStatus, frame: np.ndarray):
  224. """处理违规"""
  225. current_time = time.time()
  226. # 检查冷却时间(按违规类型)
  227. violation_key = status.get_violation_desc()
  228. cooldown = SAFETY_DETECTION_CONFIG.get('alert_cooldown', 3.0)
  229. if violation_key in self.alert_cooldown:
  230. if current_time - self.alert_cooldown[violation_key] < cooldown:
  231. return
  232. # 记录告警
  233. self.alert_cooldown[violation_key] = current_time
  234. description = status.get_violation_desc()
  235. violation_type = status.violation_types[0].value if status.violation_types else "未知"
  236. # 裁剪人体区域
  237. x1, y1, x2, y2 = status.person_bbox
  238. margin = 20
  239. x1 = max(0, x1 - margin)
  240. y1 = max(0, y1 - margin)
  241. x2 = min(frame.shape[1], x2 + margin)
  242. y2 = min(frame.shape[0], y2 + margin)
  243. person_image = frame[y1:y2, x1:x2].copy()
  244. record = AlertRecord(
  245. track_id=0, # 轨迹追踪已禁用
  246. violation_type=violation_type,
  247. description=description,
  248. frame=person_image,
  249. timestamp=current_time
  250. )
  251. self.alert_records.append(record)
  252. self._update_stats('violations_detected')
  253. # PTZ跟踪已禁用
  254. # 回调
  255. if self.on_violation_detected:
  256. self.on_violation_detected(status, frame)
  257. # 推送事件
  258. if self.event_pusher:
  259. self.event_pusher.push_safety_violation(
  260. description=description,
  261. image=person_image,
  262. track_id=0, # 轨迹追踪已禁用
  263. confidence=status.person_conf
  264. )
  265. self._update_stats('events_pushed')
  266. print(f"[告警] {description}")
  267. # PTZ跟踪已禁用 - _track_violator_ptz 和 _ptz_worker 方法已移除
  268. def _set_state(self, state: CoordinatorState):
  269. """设置状态"""
  270. with self.state_lock:
  271. self.state = state
  272. def get_state(self) -> CoordinatorState:
  273. """获取状态"""
  274. with self.state_lock:
  275. return self.state
  276. def _update_stats(self, key: str, value: int = 1):
  277. """更新统计"""
  278. with self.stats_lock:
  279. if key in self.stats:
  280. self.stats[key] += value
  281. def _print_stats(self):
  282. """打印统计"""
  283. with self.stats_lock:
  284. if self.stats['start_time']:
  285. elapsed = time.time() - self.stats['start_time']
  286. print("\n=== 安全检测统计 ===")
  287. print(f"运行时长: {elapsed:.1f}秒")
  288. print(f"处理帧数: {self.stats['frames_processed']}")
  289. print(f"检测人员: {self.stats['persons_detected']}次")
  290. print(f"违规检测: {self.stats['violations_detected']}次")
  291. print(f"事件推送: {self.stats['events_pushed']}次")
  292. if self.event_pusher:
  293. push_stats = self.event_pusher.get_stats()
  294. print(f"推送详情: 成功{push_stats['pushed_events']}, 失败{push_stats['failed_events']}")
  295. print("===================\n")
  296. def get_stats(self) -> Dict:
  297. """获取统计"""
  298. with self.stats_lock:
  299. return self.stats.copy()
  300. def get_alerts(self) -> List[AlertRecord]:
  301. """获取告警记录"""
  302. return self.alert_records.copy()
  303. def force_detect(self, frame: np.ndarray = None) -> Tuple[List[SafetyDetection], List[PersonSafetyStatus]]:
  304. """
  305. 强制执行一次检测
  306. Args:
  307. frame: 输入帧,如果为 None 则从摄像头获取
  308. Returns:
  309. (检测结果, 安全状态列表)
  310. """
  311. if frame is None:
  312. frame = self.camera.get_frame() if self.camera else None
  313. if frame is None or self.detector is None:
  314. return [], []
  315. detections = self.detector.detect(frame)
  316. status_list = self.detector.check_safety(frame, detections)
  317. return detections, status_list
  318. class SimpleCamera:
  319. """简单摄像头封装(用于测试)"""
  320. def __init__(self, source=0):
  321. """
  322. 初始化摄像头
  323. Args:
  324. source: 视频源 (摄像头索引、RTSP地址、视频文件路径)
  325. """
  326. self.cap = None
  327. self.source = source
  328. self.connected = False
  329. def connect(self) -> bool:
  330. """连接摄像头"""
  331. try:
  332. # 使用 FFmpeg 单线程模式避免线程安全崩溃
  333. import os
  334. os.environ['OPENCV_FFMPEG_CAPTURE_OPTIONS'] = 'threads;1'
  335. self.cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
  336. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  337. self.connected = self.cap.isOpened()
  338. return self.connected
  339. except Exception as e:
  340. print(f"连接摄像头失败: {e}")
  341. return False
  342. def disconnect(self):
  343. """断开连接"""
  344. if self.cap:
  345. self.cap.release()
  346. self.connected = False
  347. def get_frame(self) -> Optional[np.ndarray]:
  348. """获取帧"""
  349. if self.cap is None or not self.cap.isOpened():
  350. return None
  351. ret, frame = self.cap.read()
  352. return frame if ret else None
  353. def create_coordinator(camera_source=0, config: Dict = None) -> SafetyCoordinator:
  354. """
  355. 创建安全联动控制器
  356. Args:
  357. camera_source: 摄像头源
  358. config: 配置
  359. Returns:
  360. SafetyCoordinator 实例
  361. """
  362. camera = SimpleCamera(camera_source)
  363. return SafetyCoordinator(camera, config)