test_polling_tracker.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import sys
  2. import os
  3. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  4. import numpy as np
  5. import pytest
  6. from polling_tracker import PollingTrackingCoordinator, CaptureRecord
  7. from tracker import TrackedPerson
  8. class FakePanorama:
  9. def __init__(self):
  10. self.frame = np.zeros((480, 640, 3), dtype=np.uint8)
  11. def get_frame(self):
  12. return self.frame.copy()
  13. class FakePTZ:
  14. def __init__(self):
  15. self.commands = []
  16. self.current_position = type("P", (), {"pan": 0, "tilt": 0, "zoom": 1})()
  17. def goto_exact_position(self, pan, tilt, zoom):
  18. self.commands.append((pan, tilt, zoom))
  19. return True
  20. def get_current_position(self):
  21. return self.current_position
  22. def calculate_ptz_position(self, x, y, zoom=None):
  23. return x * 180, y * 90, zoom or 8
  24. class FakeTracker:
  25. def __init__(self, persons):
  26. self.persons = persons
  27. def update(self, frame):
  28. return self.persons
  29. def test_update_active_targets():
  30. pan = FakePanorama()
  31. ptz = FakePTZ()
  32. tracker = FakeTracker([
  33. TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
  34. TrackedPerson(track_id=2, bbox=(50, 60, 70, 80), center=(60, 70), confidence=0.8),
  35. ])
  36. coord = PollingTrackingCoordinator(pan, ptz, tracker, config={"max_tracking_targets": 4})
  37. frame = pan.get_frame()
  38. coord._update_active_targets(tracker.update(frame), frame.shape)
  39. assert len(coord.active_targets) == 2
  40. assert 1 in coord.target_order
  41. assert 2 in coord.target_order
  42. def test_advance_loop():
  43. coord = PollingTrackingCoordinator.__new__(PollingTrackingCoordinator)
  44. coord.target_order = [1, 2, 3]
  45. coord.current_index = 0
  46. coord._advance()
  47. assert coord.current_index == 1
  48. coord.current_index = 2
  49. coord._advance()
  50. assert coord.current_index == 0
def test_capture_record_creation():
    """Build a CaptureRecord from keyword fields and check track_id is kept."""
    record = CaptureRecord(
        track_id=1,
        timestamp=1.0,
        position=(0.5, 0.5),
        ptz_position=(90.0, 45.0, 8),
        # Dummy all-zero 100x100x3 uint8 image standing in for a PTZ capture.
        ptz_image=np.zeros((100, 100, 3), dtype=np.uint8),
        panorama_image=None,
        confidence=0.9,
    )
    assert record.track_id == 1