polling_tracker.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
"""
Polling tracking + PTZ capture coordinator (轮询跟踪 + PTZ 抓拍协调器).
"""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

from config import TRACKING_CONFIG
from coordinator import TargetSelector
from tracker import TrackedPerson, UltralyticsTracker
  15. logger = logging.getLogger(__name__)
  16. @dataclass
  17. class CaptureRecord:
  18. """单次抓拍记录"""
  19. track_id: int
  20. timestamp: float
  21. position: Tuple[float, float]
  22. ptz_position: Tuple[float, float, int]
  23. ptz_image: np.ndarray
  24. panorama_image: Optional[np.ndarray]
  25. confidence: float
  26. class PollingTrackingCoordinator:
  27. """多目标轮询跟踪 + PTZ 抓拍协调器"""
  28. def __init__(
  29. self,
  30. panorama_camera,
  31. ptz_camera,
  32. tracker: UltralyticsTracker,
  33. config: Optional[Dict] = None,
  34. calibrator=None,
  35. ):
  36. self.panorama = panorama_camera
  37. self.ptz = ptz_camera
  38. self.tracker = tracker
  39. self.config = config or TRACKING_CONFIG
  40. self.calibrator = calibrator
  41. self.active_targets: Dict[int, TrackedPerson] = {}
  42. self.target_order: List[int] = []
  43. self.current_index: int = 0
  44. self.batch_captures: List[CaptureRecord] = []
  45. self._capture_counts: Dict[int, int] = {}
  46. self.targets_lock = threading.Lock()
  47. self.running = False
  48. self._detection_thread = None
  49. self._ptz_thread = None
  50. self._paused = False
  51. self._paused_event = threading.Event()
  52. self._paused_event.set()
  53. self.target_selector = TargetSelector(self.config.get("target_selection", {}))
  54. self.event_pusher = None
  55. self.stats = {
  56. "frames_processed": 0,
  57. "persons_detected": 0,
  58. "captures": 0,
  59. "uploads": 0,
  60. "start_time": None,
  61. }
  62. self.stats_lock = threading.Lock()
  63. self._ensure_capture_dir()
  64. def set_event_pusher(self, event_pusher):
  65. self.event_pusher = event_pusher
  66. def _ensure_capture_dir(self):
  67. capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
  68. try:
  69. os.makedirs(capture_dir, exist_ok=True)
  70. except OSError as e:
  71. logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
  72. def start(self) -> bool:
  73. if not self.panorama.connect():
  74. print("连接全景摄像头失败")
  75. return False
  76. if not self.ptz.connect():
  77. print("连接球机失败")
  78. self.panorama.disconnect()
  79. return False
  80. if not self.panorama.start_stream_rtsp():
  81. print("启动全景视频流失败")
  82. self.panorama.disconnect()
  83. self.ptz.disconnect()
  84. return False
  85. self.running = True
  86. self._detection_thread = threading.Thread(target=self._detection_worker, daemon=True)
  87. self._detection_thread.start()
  88. self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
  89. self._ptz_thread.start()
  90. with self.stats_lock:
  91. self.stats["start_time"] = time.time()
  92. print("轮询跟踪抓拍协调器已启动")
  93. return True
  94. def stop(self):
  95. self.running = False
  96. if self._detection_thread:
  97. self._detection_thread.join(timeout=3)
  98. if self._ptz_thread:
  99. self._ptz_thread.join(timeout=3)
  100. self.panorama.disconnect()
  101. self.ptz.disconnect()
  102. print("轮询跟踪抓拍协调器已停止")
  103. def pause(self):
  104. self._paused = True
  105. self._paused_event.clear()
  106. def resume(self):
  107. self._paused = False
  108. self._paused_event.set()
  109. def _detection_worker(self):
  110. self._paused_event.wait()
  111. detection_fps = self.config.get("detection_fps", 2)
  112. detection_interval = 1.0 / detection_fps
  113. last_detection_time = 0
  114. while self.running:
  115. try:
  116. frame = self.panorama.get_frame()
  117. if frame is None:
  118. time.sleep(0.01)
  119. continue
  120. self._update_stats("frames_processed")
  121. current_time = time.time()
  122. if not self._paused and current_time - last_detection_time >= detection_interval:
  123. last_detection_time = current_time
  124. tracked = self.tracker.update(frame)
  125. self._update_active_targets(tracked, frame.shape)
  126. if tracked:
  127. self._update_stats("persons_detected", len(tracked))
  128. time.sleep(0.01)
  129. except Exception as e:
  130. logger.error(f"检测线程错误: {e}")
  131. time.sleep(0.1)
  132. def _update_active_targets(self, tracked: List[TrackedPerson], frame_shape):
  133. current_time = time.time()
  134. frame_h, frame_w = frame_shape[:2]
  135. timeout = self.config.get("tracking_timeout", 3.0)
  136. max_targets = self.config.get("max_tracking_targets", 4)
  137. with self.targets_lock:
  138. # 更新或新增
  139. updated_ids = set()
  140. for p in tracked:
  141. if p.track_id < 0:
  142. continue
  143. updated_ids.add(p.track_id)
  144. self.active_targets[p.track_id] = p
  145. if p.track_id not in self.target_order:
  146. self.target_order.append(p.track_id)
  147. # 标记丢失
  148. lost_ids = [tid for tid in self.target_order if tid not in updated_ids]
  149. for tid in lost_ids:
  150. t = self.active_targets.get(tid)
  151. if t is not None:
  152. t.lost = True
  153. # 移除长期丢失
  154. remove_ids = []
  155. for tid in self.target_order:
  156. t = self.active_targets.get(tid)
  157. if t is None:
  158. remove_ids.append(tid)
  159. continue
  160. # 简单丢失超时移除
  161. if t.lost:
  162. remove_ids.append(tid)
  163. for tid in remove_ids:
  164. if tid in self.target_order:
  165. self.target_order.remove(tid)
  166. self.active_targets.pop(tid, None)
  167. self._capture_counts.pop(tid, None)
  168. # 人数上限淘汰
  169. if len(self.active_targets) > max_targets:
  170. self._prune_targets(frame_w, frame_h, max_targets)
  171. def _prune_targets(self, frame_w: int, frame_h: int, max_targets: int):
  172. targets = list(self.active_targets.values())
  173. frame_size = (frame_w, frame_h)
  174. scored = []
  175. for t in targets:
  176. target_wrapper = type("T", (), {
  177. "track_id": t.track_id,
  178. "area": (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1]),
  179. "confidence": t.confidence,
  180. "center_distance": self._center_distance(t.center, frame_size),
  181. "score": 0.0,
  182. })()
  183. target_wrapper.score = self.target_selector.calculate_score(target_wrapper, frame_size)
  184. scored.append(target_wrapper)
  185. scored.sort(key=lambda x: x.score, reverse=True)
  186. keep_ids = {t.track_id for t in scored[:max_targets]}
  187. remove_ids = [tid for tid in self.active_targets if tid not in keep_ids]
  188. for tid in remove_ids:
  189. self.active_targets.pop(tid, None)
  190. if tid in self.target_order:
  191. self.target_order.remove(tid)
  192. self._capture_counts.pop(tid, None)
  193. def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
  194. cx, cy = frame_size[0] / 2, frame_size[1] / 2
  195. dx = abs(center[0] - cx) / cx
  196. dy = abs(center[1] - cy) / cy
  197. return (dx + dy) / 2
  198. def _ptz_worker(self):
  199. while self.running:
  200. try:
  201. if self._paused:
  202. self._paused_event.wait()
  203. continue
  204. with self.targets_lock:
  205. has_targets = bool(self.active_targets)
  206. target_order_snapshot = self.target_order.copy()
  207. if not has_targets or not target_order_snapshot:
  208. self._flush_batch_if_needed()
  209. time.sleep(0.1)
  210. continue
  211. if self.current_index >= len(target_order_snapshot):
  212. self.current_index = 0
  213. target_id = target_order_snapshot[self.current_index]
  214. with self.targets_lock:
  215. target = self.active_targets.get(target_id)
  216. if target is None or target.lost:
  217. self._advance(len(target_order_snapshot))
  218. continue
  219. record = self._capture_one(target)
  220. if record:
  221. self.batch_captures.append(record)
  222. self._update_stats("captures")
  223. self._advance(len(target_order_snapshot))
  224. # 一轮完成
  225. if self.current_index == 0 and self.batch_captures:
  226. self._upload_batch(self.batch_captures)
  227. self.batch_captures.clear()
  228. except Exception as e:
  229. logger.error(f"PTZ 线程错误: {e}")
  230. time.sleep(0.1)
  231. def _advance(self, order_len: int = None):
  232. if order_len is None:
  233. order_len = len(self.target_order) or 1
  234. self.current_index = (self.current_index + 1) % order_len
  235. def _capture_one(self, target: TrackedPerson) -> Optional[CaptureRecord]:
  236. frame = self.panorama.get_frame()
  237. if frame is None:
  238. return None
  239. frame_h, frame_w = frame.shape[:2]
  240. x_ratio = target.center[0] / frame_w
  241. y_ratio = target.center[1] / frame_h
  242. if self.calibrator and self.calibrator.is_calibrated():
  243. pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
  244. ptz_config = getattr(self.ptz, "ptz_config", {})
  245. if ptz_config.get("pan_flip"):
  246. pan = (pan + 180) % 360
  247. zoom = ptz_config.get("default_zoom", 8)
  248. else:
  249. pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)
  250. success = self.ptz.goto_exact_position(pan, tilt, zoom)
  251. if not success:
  252. return None
  253. time.sleep(self.config.get("ptz_stabilize_time", 2.0))
  254. ptz_frame = self._get_clear_ptz_frame()
  255. if ptz_frame is None:
  256. return None
  257. max_cap = self.config.get("max_capture_per_target", 0)
  258. if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
  259. return None
  260. self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
  261. panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
  262. # 本地保存
  263. self._save_local(ptz_frame, panorama_image, target, pan, tilt, zoom)
  264. return CaptureRecord(
  265. track_id=target.track_id,
  266. timestamp=time.time(),
  267. position=(x_ratio, y_ratio),
  268. ptz_position=(pan, tilt, zoom),
  269. ptz_image=ptz_frame,
  270. panorama_image=panorama_image,
  271. confidence=target.confidence,
  272. )
  273. def _get_clear_ptz_frame(self, max_attempts: int = 5, wait_interval: float = 0.2) -> Optional[np.ndarray]:
  274. best_frame = None
  275. best_score = -1
  276. for _ in range(max_attempts):
  277. frame = self.ptz.get_frame()
  278. if frame is not None:
  279. frame_copy = frame.copy()
  280. gray = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2GRAY)
  281. score = cv2.Laplacian(gray, cv2.CV_64F).var()
  282. if score > best_score:
  283. best_score = score
  284. best_frame = frame_copy
  285. time.sleep(wait_interval)
  286. return best_frame
  287. def _save_local(self, ptz_frame, panorama_image, target, pan, tilt, zoom):
  288. capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
  289. try:
  290. os.makedirs(capture_dir, exist_ok=True)
  291. except OSError as e:
  292. logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
  293. return
  294. timestamp = int(time.time() * 1000)
  295. base = f"{capture_dir}/ptz_{target.track_id}_{timestamp}_{pan:.0f}_{tilt:.0f}_z{zoom}.jpg"
  296. cv2.imwrite(base, ptz_frame)
  297. if panorama_image is not None:
  298. pan_base = f"{capture_dir}/panorama_{target.track_id}_{timestamp}.jpg"
  299. cv2.imwrite(pan_base, panorama_image)
  300. def _upload_batch(self, records: List[CaptureRecord]):
  301. if not self.event_pusher or not self.config.get("enable_upload", True):
  302. return
  303. try:
  304. uploads = []
  305. for r in records:
  306. ptz_url = self.event_pusher.upload_numpy_image(r.ptz_image)
  307. pan_url = None
  308. if r.panorama_image is not None:
  309. pan_url = self.event_pusher.upload_numpy_image(r.panorama_image)
  310. uploads.append({
  311. "track_id": r.track_id,
  312. "ptz_image_url": ptz_url,
  313. "panorama_image_url": pan_url,
  314. "position": r.position,
  315. "ptz_position": r.ptz_position,
  316. "confidence": r.confidence,
  317. "timestamp": r.timestamp,
  318. })
  319. self.event_pusher.push_tracking_capture(batch_time=time.time(), captures=uploads)
  320. self._update_stats("uploads")
  321. except Exception as e:
  322. logger.error(f"批量上传失败: {e}")
  323. def _flush_batch_if_needed(self):
  324. if self.batch_captures:
  325. self._upload_batch(self.batch_captures)
  326. self.batch_captures.clear()
  327. def _update_stats(self, key: str, value: int = 1):
  328. with self.stats_lock:
  329. if key in self.stats:
  330. self.stats[key] += value
  331. def get_stats(self) -> Dict:
  332. with self.stats_lock:
  333. return self.stats.copy()