polling_tracker.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. """
  2. 轮询跟踪 + PTZ 抓拍协调器
  3. """
  4. import os
  5. import time
  6. import threading
  7. import logging
  8. from typing import Dict, List, Tuple, Optional
  9. from dataclasses import dataclass
  10. import cv2
  11. import numpy as np
  12. from config import TRACKING_CONFIG
  13. from tracker import UltralyticsTracker, TrackedPerson
  14. from coordinator import TargetSelector, TrackingTarget
  15. logger = logging.getLogger(__name__)
  16. @dataclass
  17. class CaptureRecord:
  18. """单次抓拍记录"""
  19. track_id: int
  20. timestamp: float
  21. position: Tuple[float, float]
  22. ptz_position: Tuple[float, float, int]
  23. ptz_image: np.ndarray
  24. panorama_image: Optional[np.ndarray]
  25. confidence: float
  26. class PollingTrackingCoordinator:
  27. """多目标轮询跟踪 + PTZ 抓拍协调器"""
  28. def __init__(
  29. self,
  30. panorama_camera,
  31. ptz_camera,
  32. tracker: UltralyticsTracker,
  33. config: Optional[Dict] = None,
  34. calibrator=None,
  35. ):
  36. self.panorama = panorama_camera
  37. self.ptz = ptz_camera
  38. self.tracker = tracker
  39. self.config = config or TRACKING_CONFIG
  40. self.calibrator = calibrator
  41. self.active_targets: Dict[int, TrackedPerson] = {}
  42. self.target_order: List[int] = []
  43. self.current_index: int = 0
  44. self.batch_captures: List[CaptureRecord] = []
  45. self._capture_counts: Dict[int, int] = {}
  46. self._last_seen_time: Dict[int, float] = {}
  47. self.targets_lock = threading.Lock()
  48. self.batch_lock = threading.Lock()
  49. self._capture_counts_lock = threading.Lock()
  50. self.running = False
  51. self._detection_thread = None
  52. self._ptz_thread = None
  53. self._paused = False
  54. self._paused_event = threading.Event()
  55. self._paused_event.set()
  56. self._last_ptz_command_time = 0.0
  57. self._last_ptz_command_time_lock = threading.Lock()
  58. self.target_selector = TargetSelector(self.config.get("target_selection", {}))
  59. self.event_pusher = None
  60. self.stats = {
  61. "frames_processed": 0,
  62. "persons_detected": 0,
  63. "captures": 0,
  64. "uploads": 0,
  65. "start_time": None,
  66. }
  67. self.stats_lock = threading.Lock()
  68. self._ensure_capture_dir()
  69. def set_event_pusher(self, event_pusher):
  70. self.event_pusher = event_pusher
  71. def _ensure_capture_dir(self):
  72. capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
  73. try:
  74. os.makedirs(capture_dir, exist_ok=True)
  75. except OSError as e:
  76. logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
  77. def start(self) -> bool:
  78. if not self.panorama.connect():
  79. logger.error("连接全景摄像头失败")
  80. return False
  81. if not self.ptz.connect():
  82. logger.error("连接球机失败")
  83. self.panorama.disconnect()
  84. return False
  85. if not self.panorama.start_stream_rtsp():
  86. logger.error("启动全景视频流失败")
  87. self.panorama.disconnect()
  88. self.ptz.disconnect()
  89. return False
  90. self.running = True
  91. self._detection_thread = threading.Thread(target=self._detection_worker, daemon=True)
  92. self._detection_thread.start()
  93. self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
  94. self._ptz_thread.start()
  95. with self.stats_lock:
  96. self.stats["start_time"] = time.time()
  97. logger.info("轮询跟踪抓拍协调器已启动")
  98. return True
  99. def stop(self):
  100. self.running = False
  101. self._paused_event.set()
  102. if self._detection_thread:
  103. self._detection_thread.join(timeout=3)
  104. if self._ptz_thread:
  105. self._ptz_thread.join(timeout=3)
  106. # 刷新待上传批次
  107. self._flush_batch_if_needed()
  108. # 停止视频流后再断开连接
  109. if hasattr(self.panorama, "stop_stream_rtsp"):
  110. try:
  111. self.panorama.stop_stream_rtsp()
  112. except Exception as e:
  113. logger.warning(f"停止全景视频流失败: {e}")
  114. self.panorama.disconnect()
  115. self.ptz.disconnect()
  116. # 释放跟踪器资源
  117. if self.tracker is not None and hasattr(self.tracker, "release"):
  118. try:
  119. self.tracker.release()
  120. except Exception as e:
  121. logger.warning(f"释放跟踪器失败: {e}")
  122. logger.info("轮询跟踪抓拍协调器已停止")
  123. def pause(self):
  124. self._paused = True
  125. self._paused_event.clear()
  126. def resume(self):
  127. self._paused = False
  128. self._paused_event.set()
  129. def pause_detection(self):
  130. """暂停检测(兼容既有协调器接口)"""
  131. self.pause()
  132. def resume_detection(self):
  133. """恢复检测(兼容既有协调器接口)"""
  134. self.resume()
  135. def _detection_worker(self):
  136. self._paused_event.wait()
  137. detection_fps = self.config.get("detection_fps", 2)
  138. detection_interval = 1.0 / detection_fps
  139. last_detection_time = 0
  140. while self.running:
  141. try:
  142. # 暂停时阻塞等待,避免忙等
  143. if self._paused:
  144. self._paused_event.wait()
  145. continue
  146. frame = self.panorama.get_frame()
  147. if frame is None:
  148. time.sleep(0.01)
  149. continue
  150. self._update_stats("frames_processed")
  151. current_time = time.time()
  152. if current_time - last_detection_time >= detection_interval:
  153. last_detection_time = current_time
  154. tracked = self.tracker.update(frame)
  155. self._update_active_targets(tracked, frame.shape)
  156. if tracked:
  157. self._update_stats("persons_detected", len(tracked))
  158. time.sleep(0.01)
  159. except Exception as e:
  160. logger.error(f"检测线程错误: {e}")
  161. time.sleep(0.1)
  162. def _update_active_targets(self, tracked: List[TrackedPerson], frame_shape):
  163. current_time = time.time()
  164. frame_h, frame_w = frame_shape[:2]
  165. timeout = self.config.get("tracking_timeout", 3.0)
  166. max_targets = self.config.get("max_tracking_targets", 4)
  167. with self.targets_lock:
  168. # 更新或新增
  169. updated_ids = set()
  170. for p in tracked:
  171. if p.track_id < 0:
  172. continue
  173. updated_ids.add(p.track_id)
  174. p.lost = False
  175. self.active_targets[p.track_id] = p
  176. self._last_seen_time[p.track_id] = current_time
  177. if p.track_id not in self.target_order:
  178. self.target_order.append(p.track_id)
  179. # 标记丢失
  180. for tid in self.target_order:
  181. if tid not in updated_ids:
  182. t = self.active_targets.get(tid)
  183. if t is not None:
  184. t.lost = True
  185. # 移除长期丢失(超过 tracking_timeout)
  186. remove_ids = []
  187. for tid in self.target_order:
  188. t = self.active_targets.get(tid)
  189. if t is None:
  190. remove_ids.append(tid)
  191. continue
  192. if t.lost:
  193. last_seen = self._last_seen_time.get(tid, current_time)
  194. if current_time - last_seen >= timeout:
  195. remove_ids.append(tid)
  196. for tid in remove_ids:
  197. if tid in self.target_order:
  198. self.target_order.remove(tid)
  199. self.active_targets.pop(tid, None)
  200. self._last_seen_time.pop(tid, None)
  201. with self._capture_counts_lock:
  202. self._capture_counts.pop(tid, None)
  203. # 人数上限淘汰
  204. if len(self.active_targets) > max_targets:
  205. self._prune_targets(frame_w, frame_h, max_targets)
  206. def _prune_targets(self, frame_w: int, frame_h: int, max_targets: int):
  207. targets = list(self.active_targets.values())
  208. frame_size = (frame_w, frame_h)
  209. scored = []
  210. for t in targets:
  211. area = (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1])
  212. center_distance = self._center_distance(t.center, frame_size)
  213. target = TrackingTarget(
  214. track_id=t.track_id,
  215. position=(t.center[0] / frame_w, t.center[1] / frame_h),
  216. last_update=time.time(),
  217. area=area,
  218. confidence=t.confidence,
  219. center_distance=center_distance,
  220. )
  221. target.score = self.target_selector.calculate_score(target, frame_size)
  222. scored.append(target)
  223. scored.sort(key=lambda x: x.score, reverse=True)
  224. keep_ids = {t.track_id for t in scored[:max_targets]}
  225. remove_ids = [tid for tid in self.active_targets if tid not in keep_ids]
  226. for tid in remove_ids:
  227. self.active_targets.pop(tid, None)
  228. self._last_seen_time.pop(tid, None)
  229. if tid in self.target_order:
  230. self.target_order.remove(tid)
  231. with self._capture_counts_lock:
  232. self._capture_counts.pop(tid, None)
  233. def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
  234. cx, cy = frame_size[0] / 2, frame_size[1] / 2
  235. dx = abs(center[0] - cx) / cx
  236. dy = abs(center[1] - cy) / cy
  237. return (dx + dy) / 2
  238. def _ptz_worker(self):
  239. while self.running:
  240. try:
  241. # 暂停时阻塞等待恢复
  242. if self._paused:
  243. self._paused_event.wait()
  244. continue
  245. # 原子性获取目标快照和当前目标
  246. with self.targets_lock:
  247. target_order_snapshot = self.target_order.copy()
  248. has_targets = bool(self.active_targets)
  249. if target_order_snapshot:
  250. if self.current_index >= len(target_order_snapshot):
  251. self.current_index = 0
  252. target_id = target_order_snapshot[self.current_index]
  253. target = self.active_targets.get(target_id)
  254. else:
  255. target_id = None
  256. target = None
  257. if not has_targets or not target_order_snapshot:
  258. self._flush_batch_if_needed()
  259. time.sleep(0.1)
  260. continue
  261. if target is None or target.lost:
  262. # 目标丢失时跳过但保留在队列中,并短暂休眠避免忙等
  263. time.sleep(0.01)
  264. with self.targets_lock:
  265. self._advance(len(target_order_snapshot))
  266. continue
  267. record = self._capture_one(target)
  268. if record:
  269. with self.batch_lock:
  270. self.batch_captures.append(record)
  271. self._update_stats("captures")
  272. with self.targets_lock:
  273. self._advance(len(target_order_snapshot))
  274. # 一轮完成
  275. with self.batch_lock:
  276. if self.current_index == 0 and self.batch_captures:
  277. self._upload_batch(self.batch_captures)
  278. self.batch_captures.clear()
  279. except Exception as e:
  280. logger.error(f"PTZ 线程错误: {e}")
  281. time.sleep(0.1)
  282. def _advance(self, order_len: int = None):
  283. if order_len is None:
  284. order_len = len(self.target_order) or 1
  285. self.current_index = (self.current_index + 1) % order_len
  286. def _capture_one(self, target: TrackedPerson) -> Optional[CaptureRecord]:
  287. frame = self.panorama.get_frame()
  288. if frame is None:
  289. return None
  290. frame_h, frame_w = frame.shape[:2]
  291. x_ratio = target.center[0] / frame_w
  292. y_ratio = target.center[1] / frame_h
  293. if self.calibrator and self.calibrator.is_calibrated():
  294. # 校准器已通过真实扫描建立全景坐标到球机 PTZ 角度的映射,
  295. # 返回的角度可直接发送给球机,不应再应用 pan_flip。
  296. pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
  297. ptz_config = getattr(self.ptz, "ptz_config", {})
  298. zoom = ptz_config.get("default_zoom", 8)
  299. else:
  300. pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)
  301. # enforce PTZ command cooldown
  302. ptz_command_cooldown = self.config.get("ptz_command_cooldown", 0.2)
  303. with self._last_ptz_command_time_lock:
  304. elapsed = time.time() - self._last_ptz_command_time
  305. if elapsed < ptz_command_cooldown:
  306. time.sleep(ptz_command_cooldown - elapsed)
  307. success = self.ptz.goto_exact_position(pan, tilt, zoom)
  308. if not success:
  309. return None
  310. with self._last_ptz_command_time_lock:
  311. self._last_ptz_command_time = time.time()
  312. ptz_stabilize_time = self.config.get("ptz_stabilize_time", 2.0)
  313. time.sleep(max(ptz_stabilize_time, ptz_command_cooldown))
  314. ptz_frame = self._get_clear_ptz_frame()
  315. if ptz_frame is None:
  316. return None
  317. max_cap = self.config.get("max_capture_per_target", 0)
  318. with self._capture_counts_lock:
  319. if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
  320. return None
  321. self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
  322. panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
  323. # 本地保存
  324. self._save_local(ptz_frame, panorama_image, target, pan, tilt, zoom)
  325. return CaptureRecord(
  326. track_id=target.track_id,
  327. timestamp=time.time(),
  328. position=(x_ratio, y_ratio),
  329. ptz_position=(pan, tilt, zoom),
  330. ptz_image=ptz_frame,
  331. panorama_image=panorama_image,
  332. confidence=target.confidence,
  333. )
  334. def _get_clear_ptz_frame(self, max_attempts: int = 5, wait_interval: float = 0.2) -> Optional[np.ndarray]:
  335. best_frame = None
  336. best_score = -1
  337. for _ in range(max_attempts):
  338. frame = self.ptz.get_frame()
  339. if frame is not None:
  340. frame_copy = frame.copy()
  341. gray = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2GRAY)
  342. score = cv2.Laplacian(gray, cv2.CV_64F).var()
  343. if score > best_score:
  344. best_score = score
  345. best_frame = frame_copy
  346. time.sleep(wait_interval)
  347. return best_frame
  348. def _save_local(self, ptz_frame, panorama_image, target, pan, tilt, zoom):
  349. capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
  350. try:
  351. os.makedirs(capture_dir, exist_ok=True)
  352. except OSError as e:
  353. logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
  354. return
  355. timestamp = int(time.time() * 1000)
  356. base = f"{capture_dir}/ptz_{target.track_id}_{timestamp}_{pan:.0f}_{tilt:.0f}_z{zoom}.jpg"
  357. cv2.imwrite(base, ptz_frame)
  358. if panorama_image is not None:
  359. pan_base = f"{capture_dir}/panorama_{target.track_id}_{timestamp}.jpg"
  360. cv2.imwrite(pan_base, panorama_image)
  361. def _upload_batch(self, records: List[CaptureRecord]):
  362. if not self.event_pusher or not self.config.get("enable_upload", True):
  363. return
  364. try:
  365. uploads = []
  366. for r in records:
  367. ptz_url = self.event_pusher.upload_numpy_image(r.ptz_image)
  368. pan_url = None
  369. if r.panorama_image is not None:
  370. pan_url = self.event_pusher.upload_numpy_image(r.panorama_image)
  371. uploads.append({
  372. "track_id": r.track_id,
  373. "ptz_image_url": ptz_url,
  374. "panorama_image_url": pan_url,
  375. "position": r.position,
  376. "ptz_position": r.ptz_position,
  377. "confidence": r.confidence,
  378. "timestamp": r.timestamp,
  379. })
  380. self.event_pusher.push_tracking_capture(batch_time=time.time(), captures=uploads)
  381. self._update_stats("uploads")
  382. except Exception as e:
  383. logger.error(f"批量上传失败: {e}")
  384. def _flush_batch_if_needed(self):
  385. with self.batch_lock:
  386. if self.batch_captures:
  387. self._upload_batch(self.batch_captures)
  388. self.batch_captures.clear()
  389. def _update_stats(self, key: str, value: int = 1):
  390. with self.stats_lock:
  391. if key in self.stats:
  392. self.stats[key] += value
  393. def get_results(self) -> List[CaptureRecord]:
  394. """获取当前抓拍结果(兼容既有协调器接口)"""
  395. with self.batch_lock:
  396. return self.batch_captures.copy()
  397. def get_stats(self) -> Dict:
  398. with self.stats_lock:
  399. return self.stats.copy()