safety_coordinator.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. """
  2. 安全联动控制器
  3. 整合安全检测、事件推送和语音播报功能
  4. """
  5. import time
  6. import threading
  7. import queue
  8. from typing import Optional, List, Dict, Tuple, Callable
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. import numpy as np
  12. import cv2
  13. from config import (
  14. COORDINATOR_CONFIG,
  15. SAFETY_DETECTION_CONFIG,
  16. EVENT_PUSHER_CONFIG,
  17. EVENT_LISTENER_CONFIG,
  18. VOICE_ANNOUNCER_CONFIG,
  19. SYSTEM_CONFIG
  20. )
  21. from safety_detector import (
  22. SafetyDetector, SafetyDetection, PersonSafetyStatus,
  23. SafetyViolationType, draw_safety_result
  24. )
  25. from event_pusher import EventPusher, EventListener, SafetyEvent, EventType
  26. from voice_announcer import VoiceAnnouncer, VoicePriority
  27. class CoordinatorState(Enum):
  28. """控制器状态"""
  29. IDLE = 0 # 空闲
  30. DETECTING = 1 # 检测中
  31. TRACKING = 2 # 跟踪中
  32. ALERTING = 3 # 告警中
  33. @dataclass
  34. class AlertRecord:
  35. """告警记录"""
  36. track_id: int # 跟踪ID
  37. violation_type: str # 违规类型
  38. description: str # 描述
  39. frame: Optional[np.ndarray] # 图像
  40. timestamp: float # 时间戳
  41. pushed: bool = False # 是否已推送
  42. announced: bool = False # 是否已播报
  43. class SafetyCoordinator:
  44. """安全联动控制器:协调摄像头、安全检测、事件推送、语音播报、PTZ跟踪"""
  45. def __init__(self, camera, config: Dict = None, ptz_camera=None, calibrator=None):
  46. self.camera = camera
  47. self.config = config or {}
  48. self.ptz = ptz_camera # PTZ球机(可选)
  49. self.calibrator = calibrator # 校准器(可选)
  50. self.detector = None
  51. self.event_pusher = None
  52. self.voice_announcer = None
  53. self.event_listener = None
  54. self.state = CoordinatorState.IDLE
  55. self.state_lock = threading.Lock()
  56. self.running = False
  57. self.worker_thread = None
  58. # PTZ跟踪线程(独立于检测线程)
  59. self._ptz_thread = None
  60. self._ptz_queue: queue.Queue = queue.Queue(maxsize=10)
  61. self._ptz_cooldown = 0.15
  62. self._last_ptz_time = 0.0
  63. # 跟踪状态
  64. self.tracks = {}
  65. self.next_track_id = 1
  66. self.alert_records: List[AlertRecord] = []
  67. self.alert_cooldown = {}
  68. self.stats = {
  69. 'frames_processed': 0,
  70. 'persons_detected': 0,
  71. 'violations_detected': 0,
  72. 'events_pushed': 0,
  73. 'voice_announced': 0,
  74. 'ptz_commands_sent': 0,
  75. 'start_time': None
  76. }
  77. self.stats_lock = threading.Lock()
  78. self.on_violation_detected: Optional[Callable] = None
  79. self.on_frame_processed: Optional[Callable] = None
  80. self._init_components()
  81. def _init_components(self):
  82. """初始化各组件"""
  83. # 从 SYSTEM_CONFIG 读取功能开关
  84. enable_detection = SYSTEM_CONFIG.get('enable_detection', True)
  85. enable_safety_detection = SYSTEM_CONFIG.get('enable_safety_detection', True)
  86. enable_event_push = SYSTEM_CONFIG.get('enable_event_push', True)
  87. enable_voice_announce = SYSTEM_CONFIG.get('enable_voice_announce', True)
  88. # 安全检测器
  89. if enable_detection and enable_safety_detection:
  90. try:
  91. self.detector = SafetyDetector(
  92. model_path=SAFETY_DETECTION_CONFIG.get('model_path'),
  93. use_gpu=SAFETY_DETECTION_CONFIG.get('use_gpu', True),
  94. conf_threshold=SAFETY_DETECTION_CONFIG.get('conf_threshold', 0.5),
  95. person_threshold=SAFETY_DETECTION_CONFIG.get('person_threshold', 0.8)
  96. )
  97. print("安全检测器初始化成功")
  98. except Exception as e:
  99. print(f"安全检测器初始化失败: {e}")
  100. else:
  101. print("安全检测功能已禁用")
  102. # 事件推送器
  103. if enable_event_push:
  104. try:
  105. self.event_pusher = EventPusher(EVENT_PUSHER_CONFIG)
  106. print("事件推送器初始化成功")
  107. except Exception as e:
  108. print(f"事件推送器初始化失败: {e}")
  109. else:
  110. print("事件推送功能已禁用")
  111. # 语音播报器
  112. if enable_voice_announce:
  113. try:
  114. self.voice_announcer = VoiceAnnouncer(
  115. tts_config=VOICE_ANNOUNCER_CONFIG.get('tts', {}),
  116. player_config=VOICE_ANNOUNCER_CONFIG.get('player', {})
  117. )
  118. print("语音播报器初始化成功")
  119. except Exception as e:
  120. print(f"语音播报器初始化失败: {e}")
  121. else:
  122. print("语音播报功能已禁用")
  123. # 事件监听器
  124. if EVENT_LISTENER_CONFIG.get('enabled', True):
  125. try:
  126. self.event_listener = EventListener(EVENT_LISTENER_CONFIG)
  127. # 设置语音播放回调
  128. self.event_listener.set_voice_callback(self._on_voice_command)
  129. print("事件监听器初始化成功")
  130. except Exception as e:
  131. print(f"事件监听器初始化失败: {e}")
  132. def _on_voice_command(self, cmd: Dict):
  133. """处理语音播放指令"""
  134. if not self.voice_announcer:
  135. return
  136. text = cmd.get('text', '')
  137. priority = VoicePriority(cmd.get('priority', 2))
  138. if text:
  139. self.voice_announcer.announce(text, priority=priority)
  140. def start(self) -> bool:
  141. """启动控制器"""
  142. if self.running:
  143. return True
  144. if self.event_pusher:
  145. self.event_pusher.start()
  146. if self.voice_announcer:
  147. self.voice_announcer.start()
  148. if self.event_listener:
  149. self.event_listener.start()
  150. self.running = True
  151. self.worker_thread = threading.Thread(target=self._worker, daemon=True)
  152. self.worker_thread.start()
  153. # 启动 PTZ 跟踪线程(如果 PTZ 可用)
  154. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  155. self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
  156. self._ptz_thread.start()
  157. print("[SafetyCoordinator] PTZ跟踪线程已启动")
  158. with self.stats_lock:
  159. self.stats['start_time'] = time.time()
  160. print("安全联动控制器已启动")
  161. return True
  162. def stop(self):
  163. """停止控制器"""
  164. self.running = False
  165. if self.worker_thread:
  166. self.worker_thread.join(timeout=3)
  167. # 停止 PTZ 跟踪线程
  168. if self._ptz_thread:
  169. self._ptz_thread.join(timeout=2)
  170. self._ptz_thread = None
  171. if self.event_pusher:
  172. self.event_pusher.stop()
  173. if self.voice_announcer:
  174. self.voice_announcer.stop()
  175. if self.event_listener:
  176. self.event_listener.stop()
  177. self._print_stats()
  178. print("安全联动控制器已停止")
  179. def _worker(self):
  180. """工作线程"""
  181. detection_interval = SAFETY_DETECTION_CONFIG.get('detection_interval', 0.1)
  182. last_detection_time = 0
  183. detection_run_count = 0
  184. detection_violation_count = 0
  185. frame_count = 0
  186. last_log_time = time.time()
  187. heartbeat_interval = 30.0
  188. last_no_detect_log_time = 0
  189. import logging
  190. sc_logger = logging.getLogger(__name__)
  191. if self.detector is None:
  192. sc_logger.warning("[安全检测] ⚠️ 安全检测器未初始化! 安全检测不可用")
  193. else:
  194. sc_logger.info(f"[安全检测] ✓ 安全检测器已就绪, 检测间隔={detection_interval}s")
  195. while self.running:
  196. try:
  197. current_time = time.time()
  198. frame = self.camera.get_frame() if self.camera else None
  199. if frame is None:
  200. time.sleep(0.01)
  201. continue
  202. frame_count += 1
  203. self._update_stats('frames_processed')
  204. if current_time - last_log_time >= heartbeat_interval:
  205. stats = self.get_stats()
  206. state_str = self.state.name if hasattr(self.state, 'name') else str(self.state)
  207. sc_logger.info(
  208. f"[安全检测] 状态={state_str}, "
  209. f"检测轮次={detection_run_count}(有人={detection_violation_count}), "
  210. f"帧数={frame_count}"
  211. )
  212. frame_count = 0
  213. last_log_time = current_time
  214. if current_time - last_detection_time >= detection_interval:
  215. last_detection_time = current_time
  216. detection_run_count += 1
  217. result = self._process_frame_with_logging(frame, detection_run_count, detection_violation_count, last_no_detect_log_time, sc_logger)
  218. detection_violation_count = result
  219. self._cleanup_tracks()
  220. time.sleep(0.01)
  221. except Exception as e:
  222. sc_logger.error(f"[安全检测] 处理错误: {e}")
  223. time.sleep(0.1)
  224. def _process_frame_with_logging(self, frame: np.ndarray, run_count: int, violation_count: int, last_no_detect_time: float, sc_logger) -> int:
  225. """处理帧并返回更新的violation_count"""
  226. if self.detector is None:
  227. return violation_count
  228. self._set_state(CoordinatorState.DETECTING)
  229. detections = self.detector.detect(frame)
  230. status_list = self.detector.check_safety(frame, detections)
  231. self._update_stats('persons_detected', len(status_list))
  232. self._update_tracks(detections)
  233. has_violation = False
  234. for status in status_list:
  235. if status.is_violation:
  236. self._handle_violation(status, frame)
  237. has_violation = True
  238. if has_violation:
  239. violation_count += 1
  240. if not status_list:
  241. current_time = time.time()
  242. if current_time - last_no_detect_time >= 30.0:
  243. sc_logger.info(
  244. f"[安全检测] · YOLO检测运行正常, 本轮未检测到人员 "
  245. f"(累计检测{run_count}轮, 违规{violation_count}轮)"
  246. )
  247. if self.on_frame_processed:
  248. self.on_frame_processed(frame, detections, status_list)
  249. return violation_count
  250. def _process_frame(self, frame: np.ndarray):
  251. """处理帧"""
  252. if self.detector is None:
  253. return
  254. self._set_state(CoordinatorState.DETECTING)
  255. # 安全检测
  256. detections = self.detector.detect(frame)
  257. status_list = self.detector.check_safety(frame, detections)
  258. self._update_stats('persons_detected', len(status_list))
  259. # 更新跟踪
  260. self._update_tracks(detections)
  261. # 检查违规
  262. for status in status_list:
  263. if status.is_violation:
  264. self._handle_violation(status, frame)
  265. # 回调
  266. if self.on_frame_processed:
  267. self.on_frame_processed(frame, detections, status_list)
  268. def _update_tracks(self, detections: List[SafetyDetection]):
  269. """更新跟踪状态"""
  270. current_time = time.time()
  271. persons = [d for d in detections if d.class_id == 3] # 人
  272. # 匹配现有跟踪
  273. used_ids = set()
  274. for person in persons:
  275. best_id = None
  276. min_dist = float('inf')
  277. for track_id, track in self.tracks.items():
  278. if track_id in used_ids:
  279. continue
  280. dist = np.sqrt(
  281. (person.center[0] - track['center'][0])**2 +
  282. (person.center[1] - track['center'][1])**2
  283. )
  284. if dist < min_dist and dist < 100: # 距离阈值
  285. min_dist = dist
  286. best_id = track_id
  287. if best_id is not None:
  288. # 更新现有跟踪
  289. self.tracks[best_id]['center'] = person.center
  290. self.tracks[best_id]['last_update'] = current_time
  291. person.track_id = best_id
  292. used_ids.add(best_id)
  293. else:
  294. # 新跟踪
  295. track_id = self.next_track_id
  296. self.next_track_id += 1
  297. person.track_id = track_id
  298. self.tracks[track_id] = {
  299. 'center': person.center,
  300. 'last_update': current_time,
  301. 'alerts': []
  302. }
  303. def _cleanup_tracks(self):
  304. """清理过期跟踪"""
  305. current_time = time.time()
  306. timeout = COORDINATOR_CONFIG.get('tracking_timeout', 5.0)
  307. expired = [
  308. tid for tid, t in self.tracks.items()
  309. if current_time - t['last_update'] > timeout
  310. ]
  311. for tid in expired:
  312. del self.tracks[tid]
  313. self.alert_cooldown.pop(tid, None)
  314. def _handle_violation(self, status: PersonSafetyStatus, frame: np.ndarray):
  315. """处理违规"""
  316. current_time = time.time()
  317. track_id = status.track_id
  318. # 检查冷却时间
  319. cooldown = SAFETY_DETECTION_CONFIG.get('alert_cooldown', 3.0)
  320. if track_id in self.alert_cooldown:
  321. if current_time - self.alert_cooldown[track_id] < cooldown:
  322. return
  323. # 记录告警
  324. self.alert_cooldown[track_id] = current_time
  325. description = status.get_violation_desc()
  326. violation_type = status.violation_types[0].value if status.violation_types else "未知"
  327. # 裁剪人体区域
  328. x1, y1, x2, y2 = status.person_bbox
  329. margin = 20
  330. x1 = max(0, x1 - margin)
  331. y1 = max(0, y1 - margin)
  332. x2 = min(frame.shape[1], x2 + margin)
  333. y2 = min(frame.shape[0], y2 + margin)
  334. person_image = frame[y1:y2, x1:x2].copy()
  335. record = AlertRecord(
  336. track_id=track_id,
  337. violation_type=violation_type,
  338. description=description,
  339. frame=person_image,
  340. timestamp=current_time
  341. )
  342. self.alert_records.append(record)
  343. self._update_stats('violations_detected')
  344. # PTZ 跟踪违规人员(如果 PTZ 可用且启用)
  345. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  346. self._track_violator_ptz(status, frame)
  347. # 回调
  348. if self.on_violation_detected:
  349. self.on_violation_detected(status, frame)
  350. # 推送事件
  351. if self.event_pusher:
  352. self.event_pusher.push_safety_violation(
  353. description=description,
  354. image=person_image,
  355. track_id=track_id,
  356. confidence=status.person_conf
  357. )
  358. self._update_stats('events_pushed')
  359. # 语音播报
  360. if self.voice_announcer:
  361. self.voice_announcer.announce_violation(description, urgent=True)
  362. self._update_stats('voice_announced')
  363. print(f"[告警] {description}, 跟踪ID: {track_id}")
  364. def _track_violator_ptz(self, status: PersonSafetyStatus, frame: np.ndarray):
  365. """违规人员PTZ跟踪:将违规人员在全景画面中的位置发送给PTZ线程"""
  366. if self.ptz is None:
  367. return
  368. frame_h, frame_w = frame.shape[:2]
  369. x1, y1, x2, y2 = status.person_bbox
  370. # 计算违规人员在全景画面中的相对位置
  371. center_x = (x1 + x2) / 2
  372. center_y = (y1 + y2) / 2
  373. x_ratio = center_x / frame_w
  374. y_ratio = center_y / frame_h
  375. # 冷却检查
  376. current_time = time.time()
  377. if current_time - self._last_ptz_time < self._ptz_cooldown:
  378. return
  379. # 发送PTZ命令
  380. try:
  381. self._ptz_queue.put_nowait({
  382. 'x_ratio': x_ratio,
  383. 'y_ratio': y_ratio,
  384. 'track_id': status.track_id,
  385. 'violation_type': status.violation_types[0].value if status.violation_types else 'unknown'
  386. })
  387. self._last_ptz_time = current_time
  388. self._update_stats('ptz_commands_sent')
  389. except queue.Full:
  390. pass # 队列满则丢弃,下一个检测周期会重发
  391. def _ptz_worker(self):
  392. """PTZ控制工作线程:独立处理所有PTZ命令"""
  393. while self.running:
  394. try:
  395. try:
  396. cmd = self._ptz_queue.get(timeout=0.1)
  397. except queue.Empty:
  398. continue
  399. if self.ptz is None:
  400. continue
  401. x_ratio = cmd['x_ratio']
  402. y_ratio = cmd['y_ratio']
  403. # 使用校准器转换坐标,或使用估算
  404. if self.calibrator and self.calibrator.is_calibrated():
  405. pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
  406. zoom = self.ptz.ptz_config.get('default_zoom', 8)
  407. if self.ptz.ptz_config.get('pan_flip', False):
  408. pan = (pan + 180) % 360
  409. self.ptz.goto_exact_position(pan, tilt, zoom)
  410. else:
  411. self.ptz.track_target(x_ratio, y_ratio)
  412. except Exception as e:
  413. print(f"[SafetyCoordinator] PTZ跟踪错误: {e}")
  414. time.sleep(0.05)
  415. def _set_state(self, state: CoordinatorState):
  416. """设置状态"""
  417. with self.state_lock:
  418. self.state = state
  419. def get_state(self) -> CoordinatorState:
  420. """获取状态"""
  421. with self.state_lock:
  422. return self.state
  423. def _update_stats(self, key: str, value: int = 1):
  424. """更新统计"""
  425. with self.stats_lock:
  426. if key in self.stats:
  427. self.stats[key] += value
  428. def _print_stats(self):
  429. """打印统计"""
  430. with self.stats_lock:
  431. if self.stats['start_time']:
  432. elapsed = time.time() - self.stats['start_time']
  433. print("\n=== 安全检测统计 ===")
  434. print(f"运行时长: {elapsed:.1f}秒")
  435. print(f"处理帧数: {self.stats['frames_processed']}")
  436. print(f"检测人员: {self.stats['persons_detected']}次")
  437. print(f"违规检测: {self.stats['violations_detected']}次")
  438. print(f"事件推送: {self.stats['events_pushed']}次")
  439. print(f"语音播报: {self.stats['voice_announced']}次")
  440. if self.event_pusher:
  441. push_stats = self.event_pusher.get_stats()
  442. print(f"推送详情: 成功{push_stats['pushed_events']}, 失败{push_stats['failed_events']}")
  443. if self.voice_announcer:
  444. voice_stats = self.voice_announcer.get_stats()
  445. print(f"播报详情: 成功{voice_stats['played_commands']}, 失败{voice_stats['failed_commands']}")
  446. print("===================\n")
  447. def get_stats(self) -> Dict:
  448. """获取统计"""
  449. with self.stats_lock:
  450. return self.stats.copy()
  451. def get_alerts(self) -> List[AlertRecord]:
  452. """获取告警记录"""
  453. return self.alert_records.copy()
  454. def announce(self, text: str, priority: VoicePriority = VoicePriority.NORMAL):
  455. """
  456. 手动播报语音
  457. Args:
  458. text: 播报文本
  459. priority: 优先级
  460. """
  461. if self.voice_announcer:
  462. self.voice_announcer.announce(text, priority=priority)
  463. def force_detect(self, frame: np.ndarray = None) -> Tuple[List[SafetyDetection], List[PersonSafetyStatus]]:
  464. """
  465. 强制执行一次检测
  466. Args:
  467. frame: 输入帧,如果为 None 则从摄像头获取
  468. Returns:
  469. (检测结果, 安全状态列表)
  470. """
  471. if frame is None:
  472. frame = self.camera.get_frame() if self.camera else None
  473. if frame is None or self.detector is None:
  474. return [], []
  475. detections = self.detector.detect(frame)
  476. status_list = self.detector.check_safety(frame, detections)
  477. return detections, status_list
  478. class SimpleCamera:
  479. """简单摄像头封装(用于测试)"""
  480. def __init__(self, source=0):
  481. """
  482. 初始化摄像头
  483. Args:
  484. source: 视频源 (摄像头索引、RTSP地址、视频文件路径)
  485. """
  486. self.cap = None
  487. self.source = source
  488. self.connected = False
  489. def connect(self) -> bool:
  490. """连接摄像头"""
  491. try:
  492. # 使用 FFmpeg 单线程模式避免线程安全崩溃
  493. import os
  494. os.environ['OPENCV_FFMPEG_CAPTURE_OPTIONS'] = 'threads;1'
  495. self.cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
  496. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  497. self.connected = self.cap.isOpened()
  498. return self.connected
  499. except Exception as e:
  500. print(f"连接摄像头失败: {e}")
  501. return False
  502. def disconnect(self):
  503. """断开连接"""
  504. if self.cap:
  505. self.cap.release()
  506. self.connected = False
  507. def get_frame(self) -> Optional[np.ndarray]:
  508. """获取帧"""
  509. if self.cap is None or not self.cap.isOpened():
  510. return None
  511. ret, frame = self.cap.read()
  512. return frame if ret else None
  513. def create_coordinator(camera_source=0, config: Dict = None) -> SafetyCoordinator:
  514. """
  515. 创建安全联动控制器
  516. Args:
  517. camera_source: 摄像头源
  518. config: 配置
  519. Returns:
  520. SafetyCoordinator 实例
  521. """
  522. camera = SimpleCamera(camera_source)
  523. return SafetyCoordinator(camera, config)