safety_coordinator.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. """
  2. 安全联动控制器
  3. 整合安全检测、事件推送和语音播报功能
  4. """
  5. import time
  6. import threading
  7. import queue
  8. from typing import Optional, List, Dict, Tuple, Callable
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. import numpy as np
  12. import cv2
  13. from config import (
  14. COORDINATOR_CONFIG,
  15. SAFETY_DETECTION_CONFIG,
  16. EVENT_PUSHER_CONFIG,
  17. EVENT_LISTENER_CONFIG,
  18. VOICE_ANNOUNCER_CONFIG,
  19. SYSTEM_CONFIG
  20. )
  21. from safety_detector import (
  22. SafetyDetector, SafetyDetection, PersonSafetyStatus,
  23. SafetyViolationType, draw_safety_result
  24. )
  25. from event_pusher import EventPusher, EventListener, SafetyEvent, EventType
  26. from voice_announcer import VoiceAnnouncer, VoicePriority
  27. class CoordinatorState(Enum):
  28. """控制器状态"""
  29. IDLE = 0 # 空闲
  30. DETECTING = 1 # 检测中
  31. TRACKING = 2 # 跟踪中
  32. ALERTING = 3 # 告警中
  33. @dataclass
  34. class AlertRecord:
  35. """告警记录"""
  36. track_id: int # 跟踪ID
  37. violation_type: str # 违规类型
  38. description: str # 描述
  39. frame: Optional[np.ndarray] # 图像
  40. timestamp: float # 时间戳
  41. pushed: bool = False # 是否已推送
  42. announced: bool = False # 是否已播报
  43. class SafetyCoordinator:
  44. """安全联动控制器:协调摄像头、安全检测、事件推送、语音播报、PTZ跟踪"""
  45. def __init__(self, camera, config: Dict = None, ptz_camera=None, calibrator=None):
  46. self.camera = camera
  47. self.config = config or {}
  48. self.ptz = ptz_camera # PTZ球机(可选)
  49. self.calibrator = calibrator # 校准器(可选)
  50. self.detector = None
  51. self.event_pusher = None
  52. self.voice_announcer = None
  53. self.event_listener = None
  54. self.state = CoordinatorState.IDLE
  55. self.state_lock = threading.Lock()
  56. self.running = False
  57. self.worker_thread = None
  58. # PTZ跟踪线程(独立于检测线程)
  59. self._ptz_thread = None
  60. self._ptz_queue: queue.Queue = queue.Queue(maxsize=10)
  61. self._ptz_cooldown = 0.15
  62. self._last_ptz_time = 0.0
  63. # 跟踪状态
  64. self.tracks = {}
  65. self.next_track_id = 1
  66. self.alert_records: List[AlertRecord] = []
  67. self.alert_cooldown = {}
  68. self.stats = {
  69. 'frames_processed': 0,
  70. 'persons_detected': 0,
  71. 'violations_detected': 0,
  72. 'events_pushed': 0,
  73. 'voice_announced': 0,
  74. 'ptz_commands_sent': 0,
  75. 'start_time': None
  76. }
  77. self.stats_lock = threading.Lock()
  78. self.on_violation_detected: Optional[Callable] = None
  79. self.on_frame_processed: Optional[Callable] = None
  80. self._init_components()
  81. def _init_components(self):
  82. """初始化各组件"""
  83. # 从 SYSTEM_CONFIG 读取功能开关
  84. enable_detection = SYSTEM_CONFIG.get('enable_detection', True)
  85. enable_safety_detection = SYSTEM_CONFIG.get('enable_safety_detection', True)
  86. enable_event_push = SYSTEM_CONFIG.get('enable_event_push', True)
  87. enable_voice_announce = SYSTEM_CONFIG.get('enable_voice_announce', True)
  88. # 安全检测器
  89. if enable_detection and enable_safety_detection:
  90. try:
  91. self.detector = SafetyDetector(
  92. model_path=SAFETY_DETECTION_CONFIG.get('model_path'),
  93. use_gpu=SAFETY_DETECTION_CONFIG.get('use_gpu', True),
  94. conf_threshold=SAFETY_DETECTION_CONFIG.get('conf_threshold', 0.5),
  95. person_threshold=SAFETY_DETECTION_CONFIG.get('person_threshold', 0.8)
  96. )
  97. print("安全检测器初始化成功")
  98. except Exception as e:
  99. print(f"安全检测器初始化失败: {e}")
  100. else:
  101. print("安全检测功能已禁用")
  102. # 事件推送器
  103. if enable_event_push:
  104. try:
  105. self.event_pusher = EventPusher(EVENT_PUSHER_CONFIG)
  106. print("事件推送器初始化成功")
  107. except Exception as e:
  108. print(f"事件推送器初始化失败: {e}")
  109. else:
  110. print("事件推送功能已禁用")
  111. # 语音播报器
  112. if enable_voice_announce:
  113. try:
  114. self.voice_announcer = VoiceAnnouncer(
  115. tts_config=VOICE_ANNOUNCER_CONFIG.get('tts', {}),
  116. player_config=VOICE_ANNOUNCER_CONFIG.get('player', {})
  117. )
  118. print("语音播报器初始化成功")
  119. except Exception as e:
  120. print(f"语音播报器初始化失败: {e}")
  121. else:
  122. print("语音播报功能已禁用")
  123. # 事件监听器
  124. if EVENT_LISTENER_CONFIG.get('enabled', True):
  125. try:
  126. self.event_listener = EventListener(EVENT_LISTENER_CONFIG)
  127. # 设置语音播放回调
  128. self.event_listener.set_voice_callback(self._on_voice_command)
  129. print("事件监听器初始化成功")
  130. except Exception as e:
  131. print(f"事件监听器初始化失败: {e}")
  132. def _on_voice_command(self, cmd: Dict):
  133. """处理语音播放指令"""
  134. if not self.voice_announcer:
  135. return
  136. text = cmd.get('text', '')
  137. priority = VoicePriority(cmd.get('priority', 2))
  138. if text:
  139. self.voice_announcer.announce(text, priority=priority)
  140. def start(self) -> bool:
  141. """启动控制器"""
  142. if self.running:
  143. return True
  144. if self.event_pusher:
  145. self.event_pusher.start()
  146. if self.voice_announcer:
  147. self.voice_announcer.start()
  148. if self.event_listener:
  149. self.event_listener.start()
  150. self.running = True
  151. self.worker_thread = threading.Thread(target=self._worker, daemon=True)
  152. self.worker_thread.start()
  153. # 启动 PTZ 跟踪线程(如果 PTZ 可用)
  154. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  155. self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
  156. self._ptz_thread.start()
  157. print("[SafetyCoordinator] PTZ跟踪线程已启动")
  158. with self.stats_lock:
  159. self.stats['start_time'] = time.time()
  160. print("安全联动控制器已启动")
  161. return True
  162. def stop(self):
  163. """停止控制器"""
  164. self.running = False
  165. if self.worker_thread:
  166. self.worker_thread.join(timeout=3)
  167. # 停止 PTZ 跟踪线程
  168. if self._ptz_thread:
  169. self._ptz_thread.join(timeout=2)
  170. self._ptz_thread = None
  171. if self.event_pusher:
  172. self.event_pusher.stop()
  173. if self.voice_announcer:
  174. self.voice_announcer.stop()
  175. if self.event_listener:
  176. self.event_listener.stop()
  177. self._print_stats()
  178. print("安全联动控制器已停止")
  179. def _worker(self):
  180. """工作线程"""
  181. detection_interval = SAFETY_DETECTION_CONFIG.get('detection_interval', 0.1)
  182. last_detection_time = 0
  183. while self.running:
  184. try:
  185. current_time = time.time()
  186. # 获取帧
  187. frame = self.camera.get_frame() if self.camera else None
  188. if frame is None:
  189. time.sleep(0.01)
  190. continue
  191. self._update_stats('frames_processed')
  192. # 周期性检测
  193. if current_time - last_detection_time >= detection_interval:
  194. last_detection_time = current_time
  195. self._process_frame(frame)
  196. # 清理过期跟踪
  197. self._cleanup_tracks()
  198. time.sleep(0.01)
  199. except Exception as e:
  200. print(f"处理错误: {e}")
  201. time.sleep(0.1)
  202. def _process_frame(self, frame: np.ndarray):
  203. """处理帧"""
  204. if self.detector is None:
  205. return
  206. self._set_state(CoordinatorState.DETECTING)
  207. # 安全检测
  208. detections = self.detector.detect(frame)
  209. status_list = self.detector.check_safety(frame, detections)
  210. self._update_stats('persons_detected', len(status_list))
  211. # 更新跟踪
  212. self._update_tracks(detections)
  213. # 检查违规
  214. for status in status_list:
  215. if status.is_violation:
  216. self._handle_violation(status, frame)
  217. # 回调
  218. if self.on_frame_processed:
  219. self.on_frame_processed(frame, detections, status_list)
  220. def _update_tracks(self, detections: List[SafetyDetection]):
  221. """更新跟踪状态"""
  222. current_time = time.time()
  223. persons = [d for d in detections if d.class_id == 3] # 人
  224. # 匹配现有跟踪
  225. used_ids = set()
  226. for person in persons:
  227. best_id = None
  228. min_dist = float('inf')
  229. for track_id, track in self.tracks.items():
  230. if track_id in used_ids:
  231. continue
  232. dist = np.sqrt(
  233. (person.center[0] - track['center'][0])**2 +
  234. (person.center[1] - track['center'][1])**2
  235. )
  236. if dist < min_dist and dist < 100: # 距离阈值
  237. min_dist = dist
  238. best_id = track_id
  239. if best_id is not None:
  240. # 更新现有跟踪
  241. self.tracks[best_id]['center'] = person.center
  242. self.tracks[best_id]['last_update'] = current_time
  243. person.track_id = best_id
  244. used_ids.add(best_id)
  245. else:
  246. # 新跟踪
  247. track_id = self.next_track_id
  248. self.next_track_id += 1
  249. person.track_id = track_id
  250. self.tracks[track_id] = {
  251. 'center': person.center,
  252. 'last_update': current_time,
  253. 'alerts': []
  254. }
  255. def _cleanup_tracks(self):
  256. """清理过期跟踪"""
  257. current_time = time.time()
  258. timeout = COORDINATOR_CONFIG.get('tracking_timeout', 5.0)
  259. expired = [
  260. tid for tid, t in self.tracks.items()
  261. if current_time - t['last_update'] > timeout
  262. ]
  263. for tid in expired:
  264. del self.tracks[tid]
  265. self.alert_cooldown.pop(tid, None)
  266. def _handle_violation(self, status: PersonSafetyStatus, frame: np.ndarray):
  267. """处理违规"""
  268. current_time = time.time()
  269. track_id = status.track_id
  270. # 检查冷却时间
  271. cooldown = SAFETY_DETECTION_CONFIG.get('alert_cooldown', 3.0)
  272. if track_id in self.alert_cooldown:
  273. if current_time - self.alert_cooldown[track_id] < cooldown:
  274. return
  275. # 记录告警
  276. self.alert_cooldown[track_id] = current_time
  277. description = status.get_violation_desc()
  278. violation_type = status.violation_types[0].value if status.violation_types else "未知"
  279. # 裁剪人体区域
  280. x1, y1, x2, y2 = status.person_bbox
  281. margin = 20
  282. x1 = max(0, x1 - margin)
  283. y1 = max(0, y1 - margin)
  284. x2 = min(frame.shape[1], x2 + margin)
  285. y2 = min(frame.shape[0], y2 + margin)
  286. person_image = frame[y1:y2, x1:x2].copy()
  287. record = AlertRecord(
  288. track_id=track_id,
  289. violation_type=violation_type,
  290. description=description,
  291. frame=person_image,
  292. timestamp=current_time
  293. )
  294. self.alert_records.append(record)
  295. self._update_stats('violations_detected')
  296. # PTZ 跟踪违规人员(如果 PTZ 可用且启用)
  297. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  298. self._track_violator_ptz(status, frame)
  299. # 回调
  300. if self.on_violation_detected:
  301. self.on_violation_detected(status, frame)
  302. # 推送事件
  303. if self.event_pusher:
  304. self.event_pusher.push_safety_violation(
  305. description=description,
  306. image=person_image,
  307. track_id=track_id,
  308. confidence=status.person_conf
  309. )
  310. self._update_stats('events_pushed')
  311. # 语音播报
  312. if self.voice_announcer:
  313. self.voice_announcer.announce_violation(description, urgent=True)
  314. self._update_stats('voice_announced')
  315. print(f"[告警] {description}, 跟踪ID: {track_id}")
  316. def _track_violator_ptz(self, status: PersonSafetyStatus, frame: np.ndarray):
  317. """违规人员PTZ跟踪:将违规人员在全景画面中的位置发送给PTZ线程"""
  318. if self.ptz is None:
  319. return
  320. frame_h, frame_w = frame.shape[:2]
  321. x1, y1, x2, y2 = status.person_bbox
  322. # 计算违规人员在全景画面中的相对位置
  323. center_x = (x1 + x2) / 2
  324. center_y = (y1 + y2) / 2
  325. x_ratio = center_x / frame_w
  326. y_ratio = center_y / frame_h
  327. # 冷却检查
  328. current_time = time.time()
  329. if current_time - self._last_ptz_time < self._ptz_cooldown:
  330. return
  331. # 发送PTZ命令
  332. try:
  333. self._ptz_queue.put_nowait({
  334. 'x_ratio': x_ratio,
  335. 'y_ratio': y_ratio,
  336. 'track_id': status.track_id,
  337. 'violation_type': status.violation_types[0].value if status.violation_types else 'unknown'
  338. })
  339. self._last_ptz_time = current_time
  340. self._update_stats('ptz_commands_sent')
  341. except queue.Full:
  342. pass # 队列满则丢弃,下一个检测周期会重发
  343. def _ptz_worker(self):
  344. """PTZ控制工作线程:独立处理所有PTZ命令"""
  345. while self.running:
  346. try:
  347. try:
  348. cmd = self._ptz_queue.get(timeout=0.1)
  349. except queue.Empty:
  350. continue
  351. if self.ptz is None:
  352. continue
  353. x_ratio = cmd['x_ratio']
  354. y_ratio = cmd['y_ratio']
  355. # 使用校准器转换坐标,或使用估算
  356. if self.calibrator and self.calibrator.is_calibrated():
  357. pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
  358. zoom = self.ptz.ptz_config.get('default_zoom', 8)
  359. if self.ptz.ptz_config.get('pan_flip', False):
  360. pan = (pan + 180) % 360
  361. self.ptz.goto_exact_position(pan, tilt, zoom)
  362. else:
  363. self.ptz.track_target(x_ratio, y_ratio)
  364. except Exception as e:
  365. print(f"[SafetyCoordinator] PTZ跟踪错误: {e}")
  366. time.sleep(0.05)
  367. def _set_state(self, state: CoordinatorState):
  368. """设置状态"""
  369. with self.state_lock:
  370. self.state = state
  371. def get_state(self) -> CoordinatorState:
  372. """获取状态"""
  373. with self.state_lock:
  374. return self.state
  375. def _update_stats(self, key: str, value: int = 1):
  376. """更新统计"""
  377. with self.stats_lock:
  378. if key in self.stats:
  379. self.stats[key] += value
  380. def _print_stats(self):
  381. """打印统计"""
  382. with self.stats_lock:
  383. if self.stats['start_time']:
  384. elapsed = time.time() - self.stats['start_time']
  385. print("\n=== 安全检测统计 ===")
  386. print(f"运行时长: {elapsed:.1f}秒")
  387. print(f"处理帧数: {self.stats['frames_processed']}")
  388. print(f"检测人员: {self.stats['persons_detected']}次")
  389. print(f"违规检测: {self.stats['violations_detected']}次")
  390. print(f"事件推送: {self.stats['events_pushed']}次")
  391. print(f"语音播报: {self.stats['voice_announced']}次")
  392. if self.event_pusher:
  393. push_stats = self.event_pusher.get_stats()
  394. print(f"推送详情: 成功{push_stats['pushed_events']}, 失败{push_stats['failed_events']}")
  395. if self.voice_announcer:
  396. voice_stats = self.voice_announcer.get_stats()
  397. print(f"播报详情: 成功{voice_stats['played_commands']}, 失败{voice_stats['failed_commands']}")
  398. print("===================\n")
  399. def get_stats(self) -> Dict:
  400. """获取统计"""
  401. with self.stats_lock:
  402. return self.stats.copy()
  403. def get_alerts(self) -> List[AlertRecord]:
  404. """获取告警记录"""
  405. return self.alert_records.copy()
  406. def announce(self, text: str, priority: VoicePriority = VoicePriority.NORMAL):
  407. """
  408. 手动播报语音
  409. Args:
  410. text: 播报文本
  411. priority: 优先级
  412. """
  413. if self.voice_announcer:
  414. self.voice_announcer.announce(text, priority=priority)
  415. def force_detect(self, frame: np.ndarray = None) -> Tuple[List[SafetyDetection], List[PersonSafetyStatus]]:
  416. """
  417. 强制执行一次检测
  418. Args:
  419. frame: 输入帧,如果为 None 则从摄像头获取
  420. Returns:
  421. (检测结果, 安全状态列表)
  422. """
  423. if frame is None:
  424. frame = self.camera.get_frame() if self.camera else None
  425. if frame is None or self.detector is None:
  426. return [], []
  427. detections = self.detector.detect(frame)
  428. status_list = self.detector.check_safety(frame, detections)
  429. return detections, status_list
  430. class SimpleCamera:
  431. """简单摄像头封装(用于测试)"""
  432. def __init__(self, source=0):
  433. """
  434. 初始化摄像头
  435. Args:
  436. source: 视频源 (摄像头索引、RTSP地址、视频文件路径)
  437. """
  438. self.cap = None
  439. self.source = source
  440. self.connected = False
  441. def connect(self) -> bool:
  442. """连接摄像头"""
  443. try:
  444. # 使用 FFmpeg 单线程模式避免线程安全崩溃
  445. import os
  446. os.environ['OPENCV_FFMPEG_CAPTURE_OPTIONS'] = 'threads;1'
  447. self.cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
  448. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  449. self.connected = self.cap.isOpened()
  450. return self.connected
  451. except Exception as e:
  452. print(f"连接摄像头失败: {e}")
  453. return False
  454. def disconnect(self):
  455. """断开连接"""
  456. if self.cap:
  457. self.cap.release()
  458. self.connected = False
  459. def get_frame(self) -> Optional[np.ndarray]:
  460. """获取帧"""
  461. if self.cap is None or not self.cap.isOpened():
  462. return None
  463. ret, frame = self.cap.read()
  464. return frame if ret else None
  465. def create_coordinator(camera_source=0, config: Dict = None) -> SafetyCoordinator:
  466. """
  467. 创建安全联动控制器
  468. Args:
  469. camera_source: 摄像头源
  470. config: 配置
  471. Returns:
  472. SafetyCoordinator 实例
  473. """
  474. camera = SimpleCamera(camera_source)
  475. return SafetyCoordinator(camera, config)