safety_coordinator.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. """
  2. 安全联动控制器
  3. 整合安全检测、事件推送和语音播报功能
  4. """
  5. import time
  6. import threading
  7. import queue
  8. from typing import Optional, List, Dict, Tuple, Callable
  9. from dataclasses import dataclass
  10. from enum import Enum
  11. import numpy as np
  12. import cv2
  13. from config import (
  14. COORDINATOR_CONFIG,
  15. SAFETY_DETECTION_CONFIG,
  16. EVENT_PUSHER_CONFIG,
  17. EVENT_LISTENER_CONFIG,
  18. VOICE_ANNOUNCER_CONFIG,
  19. SYSTEM_CONFIG
  20. )
  21. from safety_detector import (
  22. SafetyDetector, SafetyDetection, PersonSafetyStatus,
  23. SafetyViolationType, draw_safety_result
  24. )
  25. from event_pusher import EventPusher, EventListener, SafetyEvent, EventType
  26. from voice_announcer import VoiceAnnouncer, VoicePriority
  27. class CoordinatorState(Enum):
  28. """控制器状态"""
  29. IDLE = 0 # 空闲
  30. DETECTING = 1 # 检测中
  31. TRACKING = 2 # 跟踪中
  32. ALERTING = 3 # 告警中
  33. @dataclass
  34. class AlertRecord:
  35. """告警记录"""
  36. track_id: int # 跟踪ID
  37. violation_type: str # 违规类型
  38. description: str # 描述
  39. frame: Optional[np.ndarray] # 图像
  40. timestamp: float # 时间戳
  41. pushed: bool = False # 是否已推送
  42. announced: bool = False # 是否已播报
  43. class SafetyCoordinator:
  44. """安全联动控制器:协调摄像头、安全检测、事件推送、语音播报、PTZ跟踪"""
  45. def __init__(self, camera, config: Dict = None, ptz_camera=None, calibrator=None):
  46. self.camera = camera
  47. self.config = config or {}
  48. self.ptz = ptz_camera # PTZ球机(可选)
  49. self.calibrator = calibrator # 校准器(可选)
  50. self.detector = None
  51. self.event_pusher = None
  52. self.voice_announcer = None
  53. self.event_listener = None
  54. self.state = CoordinatorState.IDLE
  55. self.state_lock = threading.Lock()
  56. self.running = False
  57. self.worker_thread = None
  58. # PTZ跟踪线程(独立于检测线程)
  59. self._ptz_thread = None
  60. self._ptz_queue: queue.Queue = queue.Queue(maxsize=10)
  61. self._ptz_cooldown = 0.15
  62. self._last_ptz_time = 0.0
  63. # 跟踪状态
  64. self.tracks = {}
  65. self.next_track_id = 1
  66. self.alert_records: List[AlertRecord] = []
  67. self.alert_cooldown = {}
  68. self.stats = {
  69. 'frames_processed': 0,
  70. 'persons_detected': 0,
  71. 'violations_detected': 0,
  72. 'events_pushed': 0,
  73. 'voice_announced': 0,
  74. 'ptz_commands_sent': 0,
  75. 'start_time': None
  76. }
  77. self.stats_lock = threading.Lock()
  78. self.on_violation_detected: Optional[Callable] = None
  79. self.on_frame_processed: Optional[Callable] = None
  80. self._init_components()
  81. def _init_components(self):
  82. """初始化各组件"""
  83. # 从 SYSTEM_CONFIG 读取功能开关
  84. enable_detection = SYSTEM_CONFIG.get('enable_detection', True)
  85. enable_safety_detection = SYSTEM_CONFIG.get('enable_safety_detection', True)
  86. enable_event_push = SYSTEM_CONFIG.get('enable_event_push', True)
  87. enable_voice_announce = SYSTEM_CONFIG.get('enable_voice_announce', True)
  88. # 安全检测器
  89. if enable_detection and enable_safety_detection:
  90. try:
  91. self.detector = SafetyDetector(
  92. model_path=SAFETY_DETECTION_CONFIG.get('model_path'),
  93. use_gpu=SAFETY_DETECTION_CONFIG.get('use_gpu', True),
  94. conf_threshold=SAFETY_DETECTION_CONFIG.get('conf_threshold', 0.5),
  95. person_threshold=SAFETY_DETECTION_CONFIG.get('person_threshold', 0.8)
  96. )
  97. print("安全检测器初始化成功")
  98. except Exception as e:
  99. print(f"安全检测器初始化失败: {e}")
  100. else:
  101. print("安全检测功能已禁用")
  102. # 事件推送器
  103. if enable_event_push:
  104. try:
  105. self.event_pusher = EventPusher(EVENT_PUSHER_CONFIG)
  106. print("事件推送器初始化成功")
  107. except Exception as e:
  108. print(f"事件推送器初始化失败: {e}")
  109. else:
  110. print("事件推送功能已禁用")
  111. # 语音播报器
  112. if enable_voice_announce:
  113. try:
  114. self.voice_announcer = VoiceAnnouncer(
  115. tts_config=VOICE_ANNOUNCER_CONFIG.get('tts', {}),
  116. player_config=VOICE_ANNOUNCER_CONFIG.get('player', {})
  117. )
  118. print("语音播报器初始化成功")
  119. except Exception as e:
  120. print(f"语音播报器初始化失败: {e}")
  121. else:
  122. print("语音播报功能已禁用")
  123. # 事件监听器
  124. if EVENT_LISTENER_CONFIG.get('enabled', True):
  125. try:
  126. self.event_listener = EventListener(EVENT_LISTENER_CONFIG)
  127. # 设置语音播放回调
  128. self.event_listener.set_voice_callback(self._on_voice_command)
  129. print("事件监听器初始化成功")
  130. except Exception as e:
  131. print(f"事件监听器初始化失败: {e}")
  132. def _on_voice_command(self, cmd: Dict):
  133. """处理语音播放指令"""
  134. if not self.voice_announcer:
  135. return
  136. text = cmd.get('text', '')
  137. priority = VoicePriority(cmd.get('priority', 2))
  138. if text:
  139. self.voice_announcer.announce(text, priority=priority)
  140. def start(self) -> bool:
  141. """启动控制器"""
  142. if self.running:
  143. return True
  144. if self.event_pusher:
  145. self.event_pusher.start()
  146. if self.voice_announcer:
  147. self.voice_announcer.start()
  148. if self.event_listener:
  149. self.event_listener.start()
  150. self.running = True
  151. self.worker_thread = threading.Thread(target=self._worker, daemon=True)
  152. self.worker_thread.start()
  153. # 启动 PTZ 跟踪线程(如果 PTZ 可用)
  154. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  155. self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
  156. self._ptz_thread.start()
  157. print("[SafetyCoordinator] PTZ跟踪线程已启动")
  158. with self.stats_lock:
  159. self.stats['start_time'] = time.time()
  160. print("安全联动控制器已启动")
  161. return True
  162. def stop(self):
  163. """停止控制器"""
  164. self.running = False
  165. if self.worker_thread:
  166. self.worker_thread.join(timeout=3)
  167. # 停止 PTZ 跟踪线程
  168. if self._ptz_thread:
  169. self._ptz_thread.join(timeout=2)
  170. self._ptz_thread = None
  171. if self.event_pusher:
  172. self.event_pusher.stop()
  173. if self.voice_announcer:
  174. self.voice_announcer.stop()
  175. if self.event_listener:
  176. self.event_listener.stop()
  177. self._print_stats()
  178. print("安全联动控制器已停止")
  179. def _worker(self):
  180. """工作线程"""
  181. # 优先使用 detection_fps,默认每秒2帧
  182. detection_fps = SAFETY_DETECTION_CONFIG.get('detection_fps', 2)
  183. detection_interval = 1.0 / detection_fps # 根据FPS计算间隔
  184. last_detection_time = 0
  185. detection_run_count = 0
  186. detection_violation_count = 0
  187. frame_count = 0
  188. last_log_time = time.time()
  189. heartbeat_interval = 30.0
  190. last_no_detect_log_time = 0
  191. import logging
  192. sc_logger = logging.getLogger(__name__)
  193. if self.detector is None:
  194. sc_logger.warning("[安全检测] ⚠️ 安全检测器未初始化! 安全检测不可用")
  195. else:
  196. sc_logger.info(f"[安全检测] ✓ 安全检测器已就绪, 检测帧率={detection_fps}fps(间隔={detection_interval:.2f}s)")
  197. while self.running:
  198. try:
  199. current_time = time.time()
  200. frame = self.camera.get_frame() if self.camera else None
  201. if frame is None:
  202. time.sleep(0.01)
  203. continue
  204. frame_count += 1
  205. self._update_stats('frames_processed')
  206. if current_time - last_log_time >= heartbeat_interval:
  207. stats = self.get_stats()
  208. state_str = self.state.name if hasattr(self.state, 'name') else str(self.state)
  209. sc_logger.info(
  210. f"[安全检测] 状态={state_str}, "
  211. f"检测轮次={detection_run_count}(有人={detection_violation_count}), "
  212. f"帧数={frame_count}"
  213. )
  214. frame_count = 0
  215. last_log_time = current_time
  216. if current_time - last_detection_time >= detection_interval:
  217. last_detection_time = current_time
  218. detection_run_count += 1
  219. result = self._process_frame_with_logging(frame, detection_run_count, detection_violation_count, last_no_detect_log_time, sc_logger)
  220. detection_violation_count = result
  221. self._cleanup_tracks()
  222. time.sleep(0.01)
  223. except Exception as e:
  224. sc_logger.error(f"[安全检测] 处理错误: {e}")
  225. time.sleep(0.1)
  226. def _process_frame_with_logging(self, frame: np.ndarray, run_count: int, violation_count: int, last_no_detect_time: float, sc_logger) -> int:
  227. """处理帧并返回更新的violation_count"""
  228. if self.detector is None:
  229. return violation_count
  230. self._set_state(CoordinatorState.DETECTING)
  231. detections = self.detector.detect(frame)
  232. status_list = self.detector.check_safety(frame, detections)
  233. self._update_stats('persons_detected', len(status_list))
  234. self._update_tracks(detections)
  235. has_violation = False
  236. for status in status_list:
  237. if status.is_violation:
  238. self._handle_violation(status, frame)
  239. has_violation = True
  240. if has_violation:
  241. violation_count += 1
  242. if not status_list:
  243. current_time = time.time()
  244. if current_time - last_no_detect_time >= 30.0:
  245. sc_logger.info(
  246. f"[安全检测] · YOLO检测运行正常, 本轮未检测到人员 "
  247. f"(累计检测{run_count}轮, 违规{violation_count}轮)"
  248. )
  249. if self.on_frame_processed:
  250. self.on_frame_processed(frame, detections, status_list)
  251. return violation_count
  252. def _process_frame(self, frame: np.ndarray):
  253. """处理帧"""
  254. if self.detector is None:
  255. return
  256. self._set_state(CoordinatorState.DETECTING)
  257. # 安全检测
  258. detections = self.detector.detect(frame)
  259. status_list = self.detector.check_safety(frame, detections)
  260. self._update_stats('persons_detected', len(status_list))
  261. # 更新跟踪
  262. self._update_tracks(detections)
  263. # 检查违规
  264. for status in status_list:
  265. if status.is_violation:
  266. self._handle_violation(status, frame)
  267. # 回调
  268. if self.on_frame_processed:
  269. self.on_frame_processed(frame, detections, status_list)
  270. def _update_tracks(self, detections: List[SafetyDetection]):
  271. """更新跟踪状态"""
  272. current_time = time.time()
  273. persons = [d for d in detections if d.class_id == 3] # 人
  274. # 匹配现有跟踪
  275. used_ids = set()
  276. for person in persons:
  277. best_id = None
  278. min_dist = float('inf')
  279. for track_id, track in self.tracks.items():
  280. if track_id in used_ids:
  281. continue
  282. dist = np.sqrt(
  283. (person.center[0] - track['center'][0])**2 +
  284. (person.center[1] - track['center'][1])**2
  285. )
  286. if dist < min_dist and dist < 100: # 距离阈值
  287. min_dist = dist
  288. best_id = track_id
  289. if best_id is not None:
  290. # 更新现有跟踪
  291. self.tracks[best_id]['center'] = person.center
  292. self.tracks[best_id]['last_update'] = current_time
  293. person.track_id = best_id
  294. used_ids.add(best_id)
  295. else:
  296. # 新跟踪
  297. track_id = self.next_track_id
  298. self.next_track_id += 1
  299. person.track_id = track_id
  300. self.tracks[track_id] = {
  301. 'center': person.center,
  302. 'last_update': current_time,
  303. 'alerts': []
  304. }
  305. def _cleanup_tracks(self):
  306. """清理过期跟踪"""
  307. current_time = time.time()
  308. timeout = COORDINATOR_CONFIG.get('tracking_timeout', 5.0)
  309. expired = [
  310. tid for tid, t in self.tracks.items()
  311. if current_time - t['last_update'] > timeout
  312. ]
  313. for tid in expired:
  314. del self.tracks[tid]
  315. self.alert_cooldown.pop(tid, None)
  316. def _handle_violation(self, status: PersonSafetyStatus, frame: np.ndarray):
  317. """处理违规"""
  318. current_time = time.time()
  319. track_id = status.track_id
  320. # 检查冷却时间
  321. cooldown = SAFETY_DETECTION_CONFIG.get('alert_cooldown', 3.0)
  322. if track_id in self.alert_cooldown:
  323. if current_time - self.alert_cooldown[track_id] < cooldown:
  324. return
  325. # 记录告警
  326. self.alert_cooldown[track_id] = current_time
  327. description = status.get_violation_desc()
  328. violation_type = status.violation_types[0].value if status.violation_types else "未知"
  329. # 裁剪人体区域
  330. x1, y1, x2, y2 = status.person_bbox
  331. margin = 20
  332. x1 = max(0, x1 - margin)
  333. y1 = max(0, y1 - margin)
  334. x2 = min(frame.shape[1], x2 + margin)
  335. y2 = min(frame.shape[0], y2 + margin)
  336. person_image = frame[y1:y2, x1:x2].copy()
  337. record = AlertRecord(
  338. track_id=track_id,
  339. violation_type=violation_type,
  340. description=description,
  341. frame=person_image,
  342. timestamp=current_time
  343. )
  344. self.alert_records.append(record)
  345. self._update_stats('violations_detected')
  346. # PTZ 跟踪违规人员(如果 PTZ 可用且启用)
  347. if self.ptz and SYSTEM_CONFIG.get('enable_ptz_tracking', True):
  348. self._track_violator_ptz(status, frame)
  349. # 回调
  350. if self.on_violation_detected:
  351. self.on_violation_detected(status, frame)
  352. # 推送事件
  353. if self.event_pusher:
  354. self.event_pusher.push_safety_violation(
  355. description=description,
  356. image=person_image,
  357. track_id=track_id,
  358. confidence=status.person_conf
  359. )
  360. self._update_stats('events_pushed')
  361. # 语音播报
  362. if self.voice_announcer:
  363. self.voice_announcer.announce_violation(description, urgent=True)
  364. self._update_stats('voice_announced')
  365. print(f"[告警] {description}, 跟踪ID: {track_id}")
  366. def _track_violator_ptz(self, status: PersonSafetyStatus, frame: np.ndarray):
  367. """违规人员PTZ跟踪:将违规人员在全景画面中的位置发送给PTZ线程"""
  368. if self.ptz is None:
  369. return
  370. frame_h, frame_w = frame.shape[:2]
  371. x1, y1, x2, y2 = status.person_bbox
  372. # 计算违规人员在全景画面中的相对位置
  373. center_x = (x1 + x2) / 2
  374. center_y = (y1 + y2) / 2
  375. x_ratio = center_x / frame_w
  376. y_ratio = center_y / frame_h
  377. # 冷却检查
  378. current_time = time.time()
  379. if current_time - self._last_ptz_time < self._ptz_cooldown:
  380. return
  381. # 发送PTZ命令
  382. try:
  383. self._ptz_queue.put_nowait({
  384. 'x_ratio': x_ratio,
  385. 'y_ratio': y_ratio,
  386. 'track_id': status.track_id,
  387. 'violation_type': status.violation_types[0].value if status.violation_types else 'unknown'
  388. })
  389. self._last_ptz_time = current_time
  390. self._update_stats('ptz_commands_sent')
  391. except queue.Full:
  392. pass # 队列满则丢弃,下一个检测周期会重发
  393. def _ptz_worker(self):
  394. """PTZ控制工作线程:独立处理所有PTZ命令"""
  395. while self.running:
  396. try:
  397. try:
  398. cmd = self._ptz_queue.get(timeout=0.1)
  399. except queue.Empty:
  400. continue
  401. if self.ptz is None:
  402. continue
  403. x_ratio = cmd['x_ratio']
  404. y_ratio = cmd['y_ratio']
  405. # 使用校准器转换坐标,或使用估算
  406. if self.calibrator and self.calibrator.is_calibrated():
  407. pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
  408. zoom = self.ptz.ptz_config.get('default_zoom', 8)
  409. if self.ptz.ptz_config.get('pan_flip', False):
  410. pan = (pan + 180) % 360
  411. self.ptz.goto_exact_position(pan, tilt, zoom)
  412. else:
  413. self.ptz.track_target(x_ratio, y_ratio)
  414. except Exception as e:
  415. print(f"[SafetyCoordinator] PTZ跟踪错误: {e}")
  416. time.sleep(0.05)
  417. def _set_state(self, state: CoordinatorState):
  418. """设置状态"""
  419. with self.state_lock:
  420. self.state = state
  421. def get_state(self) -> CoordinatorState:
  422. """获取状态"""
  423. with self.state_lock:
  424. return self.state
  425. def _update_stats(self, key: str, value: int = 1):
  426. """更新统计"""
  427. with self.stats_lock:
  428. if key in self.stats:
  429. self.stats[key] += value
  430. def _print_stats(self):
  431. """打印统计"""
  432. with self.stats_lock:
  433. if self.stats['start_time']:
  434. elapsed = time.time() - self.stats['start_time']
  435. print("\n=== 安全检测统计 ===")
  436. print(f"运行时长: {elapsed:.1f}秒")
  437. print(f"处理帧数: {self.stats['frames_processed']}")
  438. print(f"检测人员: {self.stats['persons_detected']}次")
  439. print(f"违规检测: {self.stats['violations_detected']}次")
  440. print(f"事件推送: {self.stats['events_pushed']}次")
  441. print(f"语音播报: {self.stats['voice_announced']}次")
  442. if self.event_pusher:
  443. push_stats = self.event_pusher.get_stats()
  444. print(f"推送详情: 成功{push_stats['pushed_events']}, 失败{push_stats['failed_events']}")
  445. if self.voice_announcer:
  446. voice_stats = self.voice_announcer.get_stats()
  447. print(f"播报详情: 成功{voice_stats['played_commands']}, 失败{voice_stats['failed_commands']}")
  448. print("===================\n")
  449. def get_stats(self) -> Dict:
  450. """获取统计"""
  451. with self.stats_lock:
  452. return self.stats.copy()
  453. def get_alerts(self) -> List[AlertRecord]:
  454. """获取告警记录"""
  455. return self.alert_records.copy()
  456. def announce(self, text: str, priority: VoicePriority = VoicePriority.NORMAL):
  457. """
  458. 手动播报语音
  459. Args:
  460. text: 播报文本
  461. priority: 优先级
  462. """
  463. if self.voice_announcer:
  464. self.voice_announcer.announce(text, priority=priority)
  465. def force_detect(self, frame: np.ndarray = None) -> Tuple[List[SafetyDetection], List[PersonSafetyStatus]]:
  466. """
  467. 强制执行一次检测
  468. Args:
  469. frame: 输入帧,如果为 None 则从摄像头获取
  470. Returns:
  471. (检测结果, 安全状态列表)
  472. """
  473. if frame is None:
  474. frame = self.camera.get_frame() if self.camera else None
  475. if frame is None or self.detector is None:
  476. return [], []
  477. detections = self.detector.detect(frame)
  478. status_list = self.detector.check_safety(frame, detections)
  479. return detections, status_list
  480. class SimpleCamera:
  481. """简单摄像头封装(用于测试)"""
  482. def __init__(self, source=0):
  483. """
  484. 初始化摄像头
  485. Args:
  486. source: 视频源 (摄像头索引、RTSP地址、视频文件路径)
  487. """
  488. self.cap = None
  489. self.source = source
  490. self.connected = False
  491. def connect(self) -> bool:
  492. """连接摄像头"""
  493. try:
  494. # 使用 FFmpeg 单线程模式避免线程安全崩溃
  495. import os
  496. os.environ['OPENCV_FFMPEG_CAPTURE_OPTIONS'] = 'threads;1'
  497. self.cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
  498. self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  499. self.connected = self.cap.isOpened()
  500. return self.connected
  501. except Exception as e:
  502. print(f"连接摄像头失败: {e}")
  503. return False
  504. def disconnect(self):
  505. """断开连接"""
  506. if self.cap:
  507. self.cap.release()
  508. self.connected = False
  509. def get_frame(self) -> Optional[np.ndarray]:
  510. """获取帧"""
  511. if self.cap is None or not self.cap.isOpened():
  512. return None
  513. ret, frame = self.cap.read()
  514. return frame if ret else None
  515. def create_coordinator(camera_source=0, config: Dict = None) -> SafetyCoordinator:
  516. """
  517. 创建安全联动控制器
  518. Args:
  519. camera_source: 摄像头源
  520. config: 配置
  521. Returns:
  522. SafetyCoordinator 实例
  523. """
  524. camera = SimpleCamera(camera_source)
  525. return SafetyCoordinator(camera, config)