test_polling_tracker.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import sys
  2. import os
  3. import time
  4. import threading
  5. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  6. import numpy as np
  7. import pytest
  8. from polling_tracker import PollingTrackingCoordinator, CaptureRecord
  9. from tracker import TrackedPerson
  10. class _FakePosition:
  11. def __init__(self):
  12. self.pan = 0
  13. self.tilt = 0
  14. self.zoom = 1
  15. class FakePanorama:
  16. def __init__(self):
  17. self.frame = np.zeros((480, 640, 3), dtype=np.uint8)
  18. self.connected = False
  19. self.streaming = False
  20. self.stopped = False
  21. def connect(self):
  22. self.connected = True
  23. return True
  24. def start_stream_rtsp(self):
  25. self.streaming = True
  26. return True
  27. def stop_stream_rtsp(self):
  28. self.stopped = True
  29. return True
  30. def disconnect(self):
  31. self.connected = False
  32. self.streaming = False
  33. return True
  34. def get_frame(self):
  35. return self.frame.copy()
  36. class FakePTZ:
  37. def __init__(self):
  38. self.commands = []
  39. self.connected = False
  40. self.ptz_frame = np.zeros((100, 100, 3), dtype=np.uint8)
  41. self.current_position = _FakePosition()
  42. self.ptz_config = {"default_zoom": 8}
  43. def connect(self):
  44. self.connected = True
  45. return True
  46. def disconnect(self):
  47. self.connected = False
  48. return True
  49. def goto_exact_position(self, pan, tilt, zoom):
  50. self.commands.append((pan, tilt, zoom))
  51. return True
  52. def get_current_position(self):
  53. return self.current_position
  54. def calculate_ptz_position(self, x, y, zoom=None):
  55. return x * 180, y * 90, zoom or 8
  56. def get_frame(self):
  57. return self.ptz_frame.copy()
  58. class FakeTracker:
  59. def __init__(self, persons):
  60. self.persons = persons
  61. self.released = False
  62. def update(self, frame):
  63. return self.persons
  64. def release(self):
  65. self.released = True
  66. class FakeEventPusher:
  67. def __init__(self):
  68. self.uploads = []
  69. self.pushes = []
  70. def upload_numpy_image(self, image):
  71. url = f"url_{id(image)}"
  72. self.uploads.append(url)
  73. return url
  74. def push_tracking_capture(self, batch_time, captures):
  75. self.pushes.append({"batch_time": batch_time, "captures": captures})
  76. def test_update_active_targets():
  77. pan = FakePanorama()
  78. ptz = FakePTZ()
  79. tracker = FakeTracker([
  80. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  81. TrackedPerson(track_id=2, bbox=(50, 60, 70, 80), center=(60, 70), confidence=0.8),
  82. ])
  83. coord = PollingTrackingCoordinator(pan, ptz, tracker, config={"max_tracking_targets": 4})
  84. frame = pan.get_frame()
  85. coord._update_active_targets(tracker.update(frame), frame.shape)
  86. assert len(coord.active_targets) == 2
  87. assert 1 in coord.target_order
  88. assert 2 in coord.target_order
  89. def test_advance_loop():
  90. coord = PollingTrackingCoordinator.__new__(PollingTrackingCoordinator)
  91. coord.target_order = [1, 2, 3]
  92. coord.current_index = 0
  93. coord._advance()
  94. assert coord.current_index == 1
  95. coord.current_index = 2
  96. coord._advance()
  97. assert coord.current_index == 0
  98. def test_capture_record_creation():
  99. record = CaptureRecord(
  100. track_id=1,
  101. timestamp=1.0,
  102. position=(0.5, 0.5),
  103. ptz_position=(90.0, 45.0, 8),
  104. ptz_image=np.zeros((100, 100, 3), dtype=np.uint8),
  105. panorama_image=None,
  106. confidence=0.9,
  107. )
  108. assert record.track_id == 1
  109. def test_target_lost_and_timeout_removal():
  110. pan = FakePanorama()
  111. ptz = FakePTZ()
  112. tracker = FakeTracker([
  113. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  114. ])
  115. coord = PollingTrackingCoordinator(
  116. pan, ptz, tracker, config={"tracking_timeout": 0.2, "max_tracking_targets": 4}
  117. )
  118. frame = pan.get_frame()
  119. coord._update_active_targets(tracker.update(frame), frame.shape)
  120. assert 1 in coord.target_order
  121. assert not coord.active_targets[1].lost
  122. # 当前帧无目标:标记为丢失,但仍在 target_order 中
  123. tracker.persons = []
  124. coord._update_active_targets(tracker.update(frame), frame.shape)
  125. assert 1 in coord.target_order
  126. assert coord.active_targets[1].lost
  127. # 超时后应被移除
  128. time.sleep(0.25)
  129. coord._update_active_targets(tracker.update(frame), frame.shape)
  130. assert 1 not in coord.target_order
  131. assert 1 not in coord.active_targets
  132. def test_target_lost_not_removed_before_timeout():
  133. pan = FakePanorama()
  134. ptz = FakePTZ()
  135. tracker = FakeTracker([
  136. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  137. ])
  138. coord = PollingTrackingCoordinator(
  139. pan, ptz, tracker, config={"tracking_timeout": 1.0, "max_tracking_targets": 4}
  140. )
  141. frame = pan.get_frame()
  142. coord._update_active_targets(tracker.update(frame), frame.shape)
  143. tracker.persons = []
  144. coord._update_active_targets(tracker.update(frame), frame.shape)
  145. assert 1 in coord.target_order
  146. assert coord.active_targets[1].lost
  147. def test_batch_upload_flush():
  148. pan = FakePanorama()
  149. ptz = FakePTZ()
  150. tracker = FakeTracker([
  151. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
  152. ])
  153. coord = PollingTrackingCoordinator(
  154. pan, ptz, tracker,
  155. config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0, "enable_upload": True}
  156. )
  157. pusher = FakeEventPusher()
  158. coord.set_event_pusher(pusher)
  159. frame = pan.get_frame()
  160. coord._update_active_targets(tracker.update(frame), frame.shape)
  161. record = coord._capture_one(list(coord.active_targets.values())[0])
  162. assert record is not None
  163. coord.batch_captures.append(record)
  164. coord._flush_batch_if_needed()
  165. assert len(coord.batch_captures) == 0
  166. assert len(pusher.pushes) == 1
  167. assert len(pusher.uploads) == 2 # PTZ + panorama
  168. def test_pause_resume():
  169. pan = FakePanorama()
  170. ptz = FakePTZ()
  171. tracker = FakeTracker([])
  172. coord = PollingTrackingCoordinator(pan, ptz, tracker, config={})
  173. coord.pause()
  174. assert coord._paused is True
  175. assert coord._paused_event.is_set() is False
  176. coord.resume()
  177. assert coord._paused is False
  178. assert coord._paused_event.is_set() is True
  179. def test_thread_start_stop_lifecycle():
  180. pan = FakePanorama()
  181. ptz = FakePTZ()
  182. tracker = FakeTracker([])
  183. coord = PollingTrackingCoordinator(
  184. pan, ptz, tracker,
  185. config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0}
  186. )
  187. assert coord.start() is True
  188. assert coord.running is True
  189. assert pan.streaming is True
  190. assert ptz.connected is True
  191. time.sleep(0.05)
  192. coord.stop()
  193. assert coord.running is False
  194. assert pan.stopped is True
  195. assert tracker.released is True
  196. def test_ptz_worker_capture_flow():
  197. pan = FakePanorama()
  198. ptz = FakePTZ()
  199. tracker = FakeTracker([
  200. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
  201. ])
  202. coord = PollingTrackingCoordinator(
  203. pan, ptz, tracker,
  204. config={
  205. "ptz_stabilize_time": 0.01,
  206. "ptz_command_cooldown": 0.0,
  207. "max_capture_per_target": 1,
  208. }
  209. )
  210. frame = pan.get_frame()
  211. coord._update_active_targets(tracker.update(frame), frame.shape)
  212. target = coord.active_targets[1]
  213. record = coord._capture_one(target)
  214. assert record is not None
  215. assert record.track_id == 1
  216. assert len(ptz.commands) == 1
  217. with coord._capture_counts_lock:
  218. assert coord._capture_counts[1] == 1
  219. # 超过最大抓拍数后应返回 None
  220. record2 = coord._capture_one(target)
  221. assert record2 is None