safety_coordinator.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. """
  2. 安全联动控制器
  3. 整合安全检测、事件推送和语音播报功能
  4. """
  5. import time
  6. import threading
  7. import queue
  8. from typing import Optional, List, Dict, Tuple, Callable
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. import numpy as np
  12. import cv2
  13. from config import (
  14. COORDINATOR_CONFIG,
  15. SAFETY_DETECTION_CONFIG,
  16. EVENT_PUSHER_CONFIG,
  17. EVENT_LISTENER_CONFIG,
  18. VOICE_ANNOUNCER_CONFIG,
  19. SYSTEM_CONFIG
  20. )
  21. from safety_detector import (
  22. SafetyDetector, SafetyDetection, PersonSafetyStatus,
  23. SafetyViolationType, draw_safety_result
  24. )
  25. from event_pusher import EventPusher, EventListener, SafetyEvent, EventType
  26. from voice_announcer import VoiceAnnouncer, VoicePriority
  27. class CoordinatorState(Enum):
  28. """控制器状态"""
  29. IDLE = 0 # 空闲
  30. DETECTING = 1 # 检测中
  31. TRACKING = 2 # 跟踪中
  32. ALERTING = 3 # 告警中
  33. @dataclass
  34. class AlertRecord:
  35. """告警记录"""
  36. track_id: int # 跟踪ID
  37. violation_type: str # 违规类型
  38. description: str # 描述
  39. frame: Optional[np.ndarray] # 图像
  40. timestamp: float # 时间戳
  41. pushed: bool = False # 是否已推送
  42. announced: bool = False # 是否已播报
  43. class SafetyCoordinator:
  44. """
  45. 安全联动控制器
  46. 协调摄像头、安全检测、事件推送和语音播报
  47. """
  48. def __init__(self, camera, config: Dict = None):
  49. """
  50. 初始化安全联动控制器
  51. Args:
  52. camera: 摄像头实例 (支持 get_frame() 方法)
  53. config: 配置覆盖
  54. """
  55. self.camera = camera
  56. self.config = config or {}
  57. # 安全检测器
  58. self.detector = None
  59. # 事件推送器
  60. self.event_pusher = None
  61. # 语音播报器
  62. self.voice_announcer = None
  63. # 事件监听器
  64. self.event_listener = None
  65. # 状态
  66. self.state = CoordinatorState.IDLE
  67. self.state_lock = threading.Lock()
  68. # 运行标志
  69. self.running = False
  70. self.worker_thread = None
  71. # 跟踪状态
  72. self.tracks = {} # track_id -> {'center': (x,y), 'last_update': time, 'alerts': [...]}
  73. self.next_track_id = 1
  74. # 告警记录
  75. self.alert_records: List[AlertRecord] = []
  76. self.alert_cooldown = {} # track_id -> last_alert_time
  77. # 统计
  78. self.stats = {
  79. 'frames_processed': 0,
  80. 'persons_detected': 0,
  81. 'violations_detected': 0,
  82. 'events_pushed': 0,
  83. 'voice_announced': 0,
  84. 'start_time': None
  85. }
  86. self.stats_lock = threading.Lock()
  87. # 回调
  88. self.on_violation_detected: Optional[Callable] = None
  89. self.on_frame_processed: Optional[Callable] = None
  90. # 初始化组件
  91. self._init_components()
  92. def _init_components(self):
  93. """初始化各组件"""
  94. # 从 SYSTEM_CONFIG 读取功能开关
  95. enable_detection = SYSTEM_CONFIG.get('enable_detection', True)
  96. enable_safety_detection = SYSTEM_CONFIG.get('enable_safety_detection', True)
  97. enable_event_push = SYSTEM_CONFIG.get('enable_event_push', True)
  98. enable_voice_announce = SYSTEM_CONFIG.get('enable_voice_announce', True)
  99. # 安全检测器
  100. if enable_detection and enable_safety_detection:
  101. try:
  102. self.detector = SafetyDetector(
  103. model_path=SAFETY_DETECTION_CONFIG.get('model_path'),
  104. use_gpu=SAFETY_DETECTION_CONFIG.get('use_gpu', True),
  105. conf_threshold=SAFETY_DETECTION_CONFIG.get('conf_threshold', 0.5),
  106. person_threshold=SAFETY_DETECTION_CONFIG.get('person_threshold', 0.8)
  107. )
  108. print("安全检测器初始化成功")
  109. except Exception as e:
  110. print(f"安全检测器初始化失败: {e}")
  111. else:
  112. print("安全检测功能已禁用")
  113. # 事件推送器
  114. if enable_event_push:
  115. try:
  116. self.event_pusher = EventPusher(EVENT_PUSHER_CONFIG)
  117. print("事件推送器初始化成功")
  118. except Exception as e:
  119. print(f"事件推送器初始化失败: {e}")
  120. else:
  121. print("事件推送功能已禁用")
  122. # 语音播报器
  123. if enable_voice_announce:
  124. try:
  125. self.voice_announcer = VoiceAnnouncer(
  126. tts_config=VOICE_ANNOUNCER_CONFIG.get('tts', {}),
  127. player_config=VOICE_ANNOUNCER_CONFIG.get('player', {})
  128. )
  129. print("语音播报器初始化成功")
  130. except Exception as e:
  131. print(f"语音播报器初始化失败: {e}")
  132. else:
  133. print("语音播报功能已禁用")
  134. # 事件监听器
  135. if EVENT_LISTENER_CONFIG.get('enabled', True):
  136. try:
  137. self.event_listener = EventListener(EVENT_LISTENER_CONFIG)
  138. # 设置语音播放回调
  139. self.event_listener.set_voice_callback(self._on_voice_command)
  140. print("事件监听器初始化成功")
  141. except Exception as e:
  142. print(f"事件监听器初始化失败: {e}")
  143. def _on_voice_command(self, cmd: Dict):
  144. """处理语音播放指令"""
  145. if not self.voice_announcer:
  146. return
  147. text = cmd.get('text', '')
  148. priority = VoicePriority(cmd.get('priority', 2))
  149. if text:
  150. self.voice_announcer.announce(text, priority=priority)
  151. def start(self) -> bool:
  152. """启动控制器"""
  153. if self.running:
  154. return True
  155. # 启动各组件
  156. if self.event_pusher:
  157. self.event_pusher.start()
  158. if self.voice_announcer:
  159. self.voice_announcer.start()
  160. if self.event_listener:
  161. self.event_listener.start()
  162. # 启动工作线程
  163. self.running = True
  164. self.worker_thread = threading.Thread(target=self._worker, daemon=True)
  165. self.worker_thread.start()
  166. with self.stats_lock:
  167. self.stats['start_time'] = time.time()
  168. print("安全联动控制器已启动")
  169. return True
  170. def stop(self):
  171. """停止控制器"""
  172. self.running = False
  173. if self.worker_thread:
  174. self.worker_thread.join(timeout=3)
  175. if self.event_pusher:
  176. self.event_pusher.stop()
  177. if self.voice_announcer:
  178. self.voice_announcer.stop()
  179. if self.event_listener:
  180. self.event_listener.stop()
  181. self._print_stats()
  182. print("安全联动控制器已停止")
  183. def _worker(self):
  184. """工作线程"""
  185. detection_interval = SAFETY_DETECTION_CONFIG.get('detection_interval', 0.1)
  186. last_detection_time = 0
  187. while self.running:
  188. try:
  189. current_time = time.time()
  190. # 获取帧
  191. frame = self.camera.get_frame() if self.camera else None
  192. if frame is None:
  193. time.sleep(0.01)
  194. continue
  195. self._update_stats('frames_processed')
  196. # 周期性检测
  197. if current_time - last_detection_time >= detection_interval:
  198. last_detection_time = current_time
  199. self._process_frame(frame)
  200. # 清理过期跟踪
  201. self._cleanup_tracks()
  202. time.sleep(0.01)
  203. except Exception as e:
  204. print(f"处理错误: {e}")
  205. time.sleep(0.1)
  206. def _process_frame(self, frame: np.ndarray):
  207. """处理帧"""
  208. if self.detector is None:
  209. return
  210. self._set_state(CoordinatorState.DETECTING)
  211. # 安全检测
  212. detections = self.detector.detect(frame)
  213. status_list = self.detector.check_safety(frame, detections)
  214. self._update_stats('persons_detected', len(status_list))
  215. # 更新跟踪
  216. self._update_tracks(detections)
  217. # 检查违规
  218. for status in status_list:
  219. if status.is_violation:
  220. self._handle_violation(status, frame)
  221. # 回调
  222. if self.on_frame_processed:
  223. self.on_frame_processed(frame, detections, status_list)
  224. def _update_tracks(self, detections: List[SafetyDetection]):
  225. """更新跟踪状态"""
  226. current_time = time.time()
  227. persons = [d for d in detections if d.class_id == 3] # 人
  228. # 匹配现有跟踪
  229. used_ids = set()
  230. for person in persons:
  231. best_id = None
  232. min_dist = float('inf')
  233. for track_id, track in self.tracks.items():
  234. if track_id in used_ids:
  235. continue
  236. dist = np.sqrt(
  237. (person.center[0] - track['center'][0])**2 +
  238. (person.center[1] - track['center'][1])**2
  239. )
  240. if dist < min_dist and dist < 100: # 距离阈值
  241. min_dist = dist
  242. best_id = track_id
  243. if best_id is not None:
  244. # 更新现有跟踪
  245. self.tracks[best_id]['center'] = person.center
  246. self.tracks[best_id]['last_update'] = current_time
  247. person.track_id = best_id
  248. used_ids.add(best_id)
  249. else:
  250. # 新跟踪
  251. track_id = self.next_track_id
  252. self.next_track_id += 1
  253. person.track_id = track_id
  254. self.tracks[track_id] = {
  255. 'center': person.center,
  256. 'last_update': current_time,
  257. 'alerts': []
  258. }
  259. def _cleanup_tracks(self):
  260. """清理过期跟踪"""
  261. current_time = time.time()
  262. timeout = COORDINATOR_CONFIG.get('tracking_timeout', 5.0)
  263. expired = [
  264. tid for tid, t in self.tracks.items()
  265. if current_time - t['last_update'] > timeout
  266. ]
  267. for tid in expired:
  268. del self.tracks[tid]
  269. self.alert_cooldown.pop(tid, None)
  270. def _handle_violation(self, status: PersonSafetyStatus, frame: np.ndarray):
  271. """处理违规"""
  272. current_time = time.time()
  273. track_id = status.track_id
  274. # 检查冷却时间
  275. cooldown = SAFETY_DETECTION_CONFIG.get('alert_cooldown', 3.0)
  276. if track_id in self.alert_cooldown:
  277. if current_time - self.alert_cooldown[track_id] < cooldown:
  278. return
  279. # 记录告警
  280. self.alert_cooldown[track_id] = current_time
  281. description = status.get_violation_desc()
  282. violation_type = status.violation_types[0].value if status.violation_types else "未知"
  283. # 裁剪人体区域
  284. x1, y1, x2, y2 = status.person_bbox
  285. margin = 20
  286. x1 = max(0, x1 - margin)
  287. y1 = max(0, y1 - margin)
  288. x2 = min(frame.shape[1], x2 + margin)
  289. y2 = min(frame.shape[0], y2 + margin)
  290. person_image = frame[y1:y2, x1:x2].copy()
  291. record = AlertRecord(
  292. track_id=track_id,
  293. violation_type=violation_type,
  294. description=description,
  295. frame=person_image,
  296. timestamp=current_time
  297. )
  298. self.alert_records.append(record)
  299. self._update_stats('violations_detected')
  300. # 回调
  301. if self.on_violation_detected:
  302. self.on_violation_detected(status, frame)
  303. # 推送事件
  304. if self.event_pusher:
  305. self.event_pusher.push_safety_violation(
  306. description=description,
  307. image=person_image,
  308. track_id=track_id,
  309. confidence=status.person_conf
  310. )
  311. self._update_stats('events_pushed')
  312. # 语音播报
  313. if self.voice_announcer:
  314. self.voice_announcer.announce_violation(description, urgent=True)
  315. self._update_stats('voice_announced')
  316. print(f"[告警] {description}, 跟踪ID: {track_id}")
  317. def _set_state(self, state: CoordinatorState):
  318. """设置状态"""
  319. with self.state_lock:
  320. self.state = state
  321. def get_state(self) -> CoordinatorState:
  322. """获取状态"""
  323. with self.state_lock:
  324. return self.state
  325. def _update_stats(self, key: str, value: int = 1):
  326. """更新统计"""
  327. with self.stats_lock:
  328. if key in self.stats:
  329. self.stats[key] += value
  330. def _print_stats(self):
  331. """打印统计"""
  332. with self.stats_lock:
  333. if self.stats['start_time']:
  334. elapsed = time.time() - self.stats['start_time']
  335. print("\n=== 安全检测统计 ===")
  336. print(f"运行时长: {elapsed:.1f}秒")
  337. print(f"处理帧数: {self.stats['frames_processed']}")
  338. print(f"检测人员: {self.stats['persons_detected']}次")
  339. print(f"违规检测: {self.stats['violations_detected']}次")
  340. print(f"事件推送: {self.stats['events_pushed']}次")
  341. print(f"语音播报: {self.stats['voice_announced']}次")
  342. if self.event_pusher:
  343. push_stats = self.event_pusher.get_stats()
  344. print(f"推送详情: 成功{push_stats['pushed_events']}, 失败{push_stats['failed_events']}")
  345. if self.voice_announcer:
  346. voice_stats = self.voice_announcer.get_stats()
  347. print(f"播报详情: 成功{voice_stats['played_commands']}, 失败{voice_stats['failed_commands']}")
  348. print("===================\n")
  349. def get_stats(self) -> Dict:
  350. """获取统计"""
  351. with self.stats_lock:
  352. return self.stats.copy()
  353. def get_alerts(self) -> List[AlertRecord]:
  354. """获取告警记录"""
  355. return self.alert_records.copy()
  356. def announce(self, text: str, priority: VoicePriority = VoicePriority.NORMAL):
  357. """
  358. 手动播报语音
  359. Args:
  360. text: 播报文本
  361. priority: 优先级
  362. """
  363. if self.voice_announcer:
  364. self.voice_announcer.announce(text, priority=priority)
  365. def force_detect(self, frame: np.ndarray = None) -> Tuple[List[SafetyDetection], List[PersonSafetyStatus]]:
  366. """
  367. 强制执行一次检测
  368. Args:
  369. frame: 输入帧,如果为 None 则从摄像头获取
  370. Returns:
  371. (检测结果, 安全状态列表)
  372. """
  373. if frame is None:
  374. frame = self.camera.get_frame() if self.camera else None
  375. if frame is None or self.detector is None:
  376. return [], []
  377. detections = self.detector.detect(frame)
  378. status_list = self.detector.check_safety(frame, detections)
  379. return detections, status_list
  380. class SimpleCamera:
  381. """简单摄像头封装(用于测试)"""
  382. def __init__(self, source=0):
  383. """
  384. 初始化摄像头
  385. Args:
  386. source: 视频源 (摄像头索引、RTSP地址、视频文件路径)
  387. """
  388. self.cap = None
  389. self.source = source
  390. self.connected = False
  391. def connect(self) -> bool:
  392. """连接摄像头"""
  393. try:
  394. self.cap = cv2.VideoCapture(self.source)
  395. self.connected = self.cap.isOpened()
  396. return self.connected
  397. except Exception as e:
  398. print(f"连接摄像头失败: {e}")
  399. return False
  400. def disconnect(self):
  401. """断开连接"""
  402. if self.cap:
  403. self.cap.release()
  404. self.connected = False
  405. def get_frame(self) -> Optional[np.ndarray]:
  406. """获取帧"""
  407. if self.cap is None or not self.cap.isOpened():
  408. return None
  409. ret, frame = self.cap.read()
  410. return frame if ret else None
  411. def create_coordinator(camera_source=0, config: Dict = None) -> SafetyCoordinator:
  412. """
  413. 创建安全联动控制器
  414. Args:
  415. camera_source: 摄像头源
  416. config: 配置
  417. Returns:
  418. SafetyCoordinator 实例
  419. """
  420. camera = SimpleCamera(camera_source)
  421. return SafetyCoordinator(camera, config)