# test_polling_tracker.py
  1. import sys
  2. import os
  3. import time
  4. import threading
  5. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  6. import numpy as np
  7. import pytest
  8. from polling_tracker import PollingTrackingCoordinator, CaptureRecord
  9. from tracker import TrackedPerson
  10. class _FakePosition:
  11. def __init__(self):
  12. self.pan = 0
  13. self.tilt = 0
  14. self.zoom = 1
  15. class FakePanorama:
  16. def __init__(self):
  17. self.frame = np.zeros((480, 640, 3), dtype=np.uint8)
  18. self.connected = False
  19. self.streaming = False
  20. self.stopped = False
  21. def connect(self):
  22. self.connected = True
  23. return True
  24. def start_stream_rtsp(self):
  25. self.streaming = True
  26. return True
  27. def stop_stream_rtsp(self):
  28. self.stopped = True
  29. return True
  30. def disconnect(self):
  31. self.connected = False
  32. self.streaming = False
  33. return True
  34. def get_frame(self):
  35. return self.frame.copy()
  36. class FakePTZ:
  37. def __init__(self):
  38. self.commands = []
  39. self.connected = False
  40. self.ptz_frame = np.zeros((100, 100, 3), dtype=np.uint8)
  41. self.current_position = _FakePosition()
  42. self.ptz_config = {"default_zoom": 8}
  43. def connect(self):
  44. self.connected = True
  45. return True
  46. def disconnect(self):
  47. self.connected = False
  48. return True
  49. def goto_exact_position(self, pan, tilt, zoom):
  50. self.commands.append((pan, tilt, zoom))
  51. return True
  52. def get_current_position(self):
  53. return self.current_position
  54. def calculate_ptz_position(self, x, y, zoom=None):
  55. return x * 180, y * 90, zoom or 8
  56. def get_frame(self):
  57. return self.ptz_frame.copy()
  58. class FakeTracker:
  59. def __init__(self, persons):
  60. self.persons = persons
  61. self.released = False
  62. def update(self, frame):
  63. return self.persons
  64. def release(self):
  65. self.released = True
  66. class FakeCalibrator:
  67. """模拟已校准的标定器,返回固定的 PTZ 角度。"""
  68. def __init__(self, pan=180.0, tilt=45.0):
  69. self._pan = pan
  70. self._tilt = tilt
  71. def transform(self, x_ratio, y_ratio):
  72. return (self._pan, self._tilt)
  73. def is_calibrated(self):
  74. return True
  75. class FakeEventPusher:
  76. def __init__(self):
  77. self.uploads = []
  78. self.pushes = []
  79. def upload_numpy_image(self, image):
  80. url = f"url_{id(image)}"
  81. self.uploads.append(url)
  82. return url
  83. def push_tracking_capture(self, batch_time, captures):
  84. self.pushes.append({"batch_time": batch_time, "captures": captures})
  85. def test_update_active_targets():
  86. pan = FakePanorama()
  87. ptz = FakePTZ()
  88. tracker = FakeTracker([
  89. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  90. TrackedPerson(track_id=2, bbox=(50, 60, 70, 80), center=(60, 70), confidence=0.8),
  91. ])
  92. coord = PollingTrackingCoordinator(pan, ptz, tracker, config={"max_tracking_targets": 4})
  93. frame = pan.get_frame()
  94. coord._update_active_targets(tracker.update(frame), frame.shape)
  95. assert len(coord.active_targets) == 2
  96. assert 1 in coord.target_order
  97. assert 2 in coord.target_order
  98. def test_advance_loop():
  99. coord = PollingTrackingCoordinator.__new__(PollingTrackingCoordinator)
  100. coord.target_order = [1, 2, 3]
  101. coord.current_index = 0
  102. coord._advance()
  103. assert coord.current_index == 1
  104. coord.current_index = 2
  105. coord._advance()
  106. assert coord.current_index == 0
  107. def test_capture_record_creation():
  108. record = CaptureRecord(
  109. track_id=1,
  110. timestamp=1.0,
  111. position=(0.5, 0.5),
  112. ptz_position=(90.0, 45.0, 8),
  113. ptz_image=np.zeros((100, 100, 3), dtype=np.uint8),
  114. panorama_image=None,
  115. confidence=0.9,
  116. )
  117. assert record.track_id == 1
  118. def test_target_lost_and_timeout_removal():
  119. pan = FakePanorama()
  120. ptz = FakePTZ()
  121. tracker = FakeTracker([
  122. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  123. ])
  124. coord = PollingTrackingCoordinator(
  125. pan, ptz, tracker, config={"tracking_timeout": 0.2, "max_tracking_targets": 4}
  126. )
  127. frame = pan.get_frame()
  128. coord._update_active_targets(tracker.update(frame), frame.shape)
  129. assert 1 in coord.target_order
  130. assert not coord.active_targets[1].lost
  131. # 当前帧无目标:标记为丢失,但仍在 target_order 中
  132. tracker.persons = []
  133. coord._update_active_targets(tracker.update(frame), frame.shape)
  134. assert 1 in coord.target_order
  135. assert coord.active_targets[1].lost
  136. # 超时后应被移除
  137. time.sleep(0.25)
  138. coord._update_active_targets(tracker.update(frame), frame.shape)
  139. assert 1 not in coord.target_order
  140. assert 1 not in coord.active_targets
  141. def test_target_lost_not_removed_before_timeout():
  142. pan = FakePanorama()
  143. ptz = FakePTZ()
  144. tracker = FakeTracker([
  145. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  146. ])
  147. coord = PollingTrackingCoordinator(
  148. pan, ptz, tracker, config={"tracking_timeout": 1.0, "max_tracking_targets": 4}
  149. )
  150. frame = pan.get_frame()
  151. coord._update_active_targets(tracker.update(frame), frame.shape)
  152. tracker.persons = []
  153. coord._update_active_targets(tracker.update(frame), frame.shape)
  154. assert 1 in coord.target_order
  155. assert coord.active_targets[1].lost
  156. def test_batch_upload_flush():
  157. pan = FakePanorama()
  158. ptz = FakePTZ()
  159. tracker = FakeTracker([
  160. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
  161. ])
  162. coord = PollingTrackingCoordinator(
  163. pan, ptz, tracker,
  164. config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0, "enable_upload": True}
  165. )
  166. pusher = FakeEventPusher()
  167. coord.set_event_pusher(pusher)
  168. frame = pan.get_frame()
  169. coord._update_active_targets(tracker.update(frame), frame.shape)
  170. record = coord._capture_one(list(coord.active_targets.values())[0])
  171. assert record is not None
  172. coord.batch_captures.append(record)
  173. coord._flush_batch_if_needed()
  174. assert len(coord.batch_captures) == 0
  175. assert len(pusher.pushes) == 1
  176. assert len(pusher.uploads) == 2 # PTZ + panorama
  177. def test_pause_resume():
  178. pan = FakePanorama()
  179. ptz = FakePTZ()
  180. tracker = FakeTracker([])
  181. coord = PollingTrackingCoordinator(pan, ptz, tracker, config={})
  182. coord.pause()
  183. assert coord._paused is True
  184. assert coord._paused_event.is_set() is False
  185. coord.resume()
  186. assert coord._paused is False
  187. assert coord._paused_event.is_set() is True
  188. def test_thread_start_stop_lifecycle():
  189. pan = FakePanorama()
  190. ptz = FakePTZ()
  191. tracker = FakeTracker([])
  192. coord = PollingTrackingCoordinator(
  193. pan, ptz, tracker,
  194. config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0}
  195. )
  196. assert coord.start() is True
  197. assert coord.running is True
  198. assert pan.streaming is True
  199. assert ptz.connected is True
  200. time.sleep(0.05)
  201. coord.stop()
  202. assert coord.running is False
  203. assert pan.stopped is True
  204. assert tracker.released is True
  205. def test_ptz_worker_capture_flow():
  206. pan = FakePanorama()
  207. ptz = FakePTZ()
  208. tracker = FakeTracker([
  209. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
  210. ])
  211. coord = PollingTrackingCoordinator(
  212. pan, ptz, tracker,
  213. config={
  214. "ptz_stabilize_time": 0.01,
  215. "ptz_command_cooldown": 0.0,
  216. "max_capture_per_target": 1,
  217. }
  218. )
  219. frame = pan.get_frame()
  220. coord._update_active_targets(tracker.update(frame), frame.shape)
  221. target = coord.active_targets[1]
  222. record = coord._capture_one(target)
  223. assert record is not None
  224. assert record.track_id == 1
  225. assert len(ptz.commands) == 1
  226. with coord._capture_counts_lock:
  227. assert coord._capture_counts[1] == 1
  228. # 超过最大抓拍数后应返回 None
  229. record2 = coord._capture_one(target)
  230. assert record2 is None
  231. def test_capture_with_pan_flip_does_not_double_flip():
  232. """
  233. 验证:当存在校准器且 ptz_config 启用 pan_flip 时,
  234. 不应在校准结果上再次应用 pan_flip。
  235. 场景:球机实际朝向与枪机相反(pan_flip=True)。
  236. 校准器通过真实扫描得到全景中心对应的球机角度为 180°,
  237. 该角度应直接发送给球机。若再次翻转,球机会被转到背面(0°)。
  238. """
  239. pan = FakePanorama()
  240. ptz = FakePTZ()
  241. ptz.ptz_config["pan_flip"] = True
  242. tracker = FakeTracker([
  243. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
  244. ])
  245. calibrator = FakeCalibrator(pan=180.0, tilt=45.0)
  246. coord = PollingTrackingCoordinator(
  247. pan, ptz, tracker,
  248. config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0},
  249. calibrator=calibrator,
  250. )
  251. frame = pan.get_frame()
  252. coord._update_active_targets(tracker.update(frame), frame.shape)
  253. target = coord.active_targets[1]
  254. record = coord._capture_one(target)
  255. assert record is not None
  256. assert len(ptz.commands) == 1
  257. sent_pan, sent_tilt, sent_zoom = ptz.commands[0]
  258. # 校准器返回的 180° 应直接发送,不能因 pan_flip 而被翻为 0°
  259. assert sent_pan == 180.0
  260. assert sent_tilt == 45.0