Bläddra i källkod

fix(polling_tracker): implement tracking timeout, thread safety, cleanup, tests

wenhongquan 1 dag sedan
förälder
incheckning
8fd4599a48
2 ändrade filer med 316 tillägg och 49 borttagningar
  1. 111 48
      dual_camera_system/polling_tracker.py
  2. 205 1
      dual_camera_system/tests/test_polling_tracker.py

+ 111 - 48
dual_camera_system/polling_tracker.py

@@ -14,7 +14,7 @@ import numpy as np
 
 from config import TRACKING_CONFIG
 from tracker import UltralyticsTracker, TrackedPerson
-from coordinator import TargetSelector
+from coordinator import TargetSelector, TrackingTarget
 
 logger = logging.getLogger(__name__)
 
@@ -53,8 +53,11 @@ class PollingTrackingCoordinator:
         self.current_index: int = 0
         self.batch_captures: List[CaptureRecord] = []
         self._capture_counts: Dict[int, int] = {}
+        self._last_seen_time: Dict[int, float] = {}
 
         self.targets_lock = threading.Lock()
+        self.batch_lock = threading.Lock()
+        self._capture_counts_lock = threading.Lock()
         self.running = False
         self._detection_thread = None
         self._ptz_thread = None
@@ -62,6 +65,9 @@ class PollingTrackingCoordinator:
         self._paused_event = threading.Event()
         self._paused_event.set()
 
+        self._last_ptz_command_time = 0.0
+        self._last_ptz_command_time_lock = threading.Lock()
+
         self.target_selector = TargetSelector(self.config.get("target_selection", {}))
         self.event_pusher = None
 
@@ -88,14 +94,14 @@ class PollingTrackingCoordinator:
 
     def start(self) -> bool:
         if not self.panorama.connect():
-            print("连接全景摄像头失败")
+            logger.error("连接全景摄像头失败")
             return False
         if not self.ptz.connect():
-            print("连接球机失败")
+            logger.error("连接球机失败")
             self.panorama.disconnect()
             return False
         if not self.panorama.start_stream_rtsp():
-            print("启动全景视频流失败")
+            logger.error("启动全景视频流失败")
             self.panorama.disconnect()
             self.ptz.disconnect()
             return False
@@ -109,18 +115,38 @@ class PollingTrackingCoordinator:
         with self.stats_lock:
             self.stats["start_time"] = time.time()
 
-        print("轮询跟踪抓拍协调器已启动")
+        logger.info("轮询跟踪抓拍协调器已启动")
         return True
 
     def stop(self):
         self.running = False
+        self._paused_event.set()
         if self._detection_thread:
             self._detection_thread.join(timeout=3)
         if self._ptz_thread:
             self._ptz_thread.join(timeout=3)
+
+        # 刷新待上传批次
+        self._flush_batch_if_needed()
+
+        # 停止视频流后再断开连接
+        if hasattr(self.panorama, "stop_stream_rtsp"):
+            try:
+                self.panorama.stop_stream_rtsp()
+            except Exception as e:
+                logger.warning(f"停止全景视频流失败: {e}")
+
         self.panorama.disconnect()
         self.ptz.disconnect()
-        print("轮询跟踪抓拍协调器已停止")
+
+        # 释放跟踪器资源
+        if self.tracker is not None and hasattr(self.tracker, "release"):
+            try:
+                self.tracker.release()
+            except Exception as e:
+                logger.warning(f"释放跟踪器失败: {e}")
+
+        logger.info("轮询跟踪抓拍协调器已停止")
 
     def pause(self):
         self._paused = True
@@ -138,6 +164,11 @@ class PollingTrackingCoordinator:
 
         while self.running:
             try:
+                # 暂停时阻塞等待,避免忙等
+                if self._paused:
+                    self._paused_event.wait()
+                    continue
+
                 frame = self.panorama.get_frame()
                 if frame is None:
                     time.sleep(0.01)
@@ -146,7 +177,7 @@ class PollingTrackingCoordinator:
                 self._update_stats("frames_processed")
 
                 current_time = time.time()
-                if not self._paused and current_time - last_detection_time >= detection_interval:
+                if current_time - last_detection_time >= detection_interval:
                     last_detection_time = current_time
                     tracked = self.tracker.update(frame)
                     self._update_active_targets(tracked, frame.shape)
@@ -171,33 +202,38 @@ class PollingTrackingCoordinator:
                 if p.track_id < 0:
                     continue
                 updated_ids.add(p.track_id)
+                p.lost = False
                 self.active_targets[p.track_id] = p
+                self._last_seen_time[p.track_id] = current_time
                 if p.track_id not in self.target_order:
                     self.target_order.append(p.track_id)
 
             # 标记丢失
-            lost_ids = [tid for tid in self.target_order if tid not in updated_ids]
-            for tid in lost_ids:
-                t = self.active_targets.get(tid)
-                if t is not None:
-                    t.lost = True
+            for tid in self.target_order:
+                if tid not in updated_ids:
+                    t = self.active_targets.get(tid)
+                    if t is not None:
+                        t.lost = True
 
-            # 移除长期丢失
+            # 移除长期丢失(超过 tracking_timeout)
             remove_ids = []
             for tid in self.target_order:
                 t = self.active_targets.get(tid)
                 if t is None:
                     remove_ids.append(tid)
                     continue
-                # 简单丢失超时移除
                 if t.lost:
-                    remove_ids.append(tid)
+                    last_seen = self._last_seen_time.get(tid, current_time)
+                    if current_time - last_seen >= timeout:
+                        remove_ids.append(tid)
 
             for tid in remove_ids:
                 if tid in self.target_order:
                     self.target_order.remove(tid)
                 self.active_targets.pop(tid, None)
-                self._capture_counts.pop(tid, None)
+                self._last_seen_time.pop(tid, None)
+                with self._capture_counts_lock:
+                    self._capture_counts.pop(tid, None)
 
             # 人数上限淘汰
             if len(self.active_targets) > max_targets:
@@ -208,23 +244,28 @@ class PollingTrackingCoordinator:
         frame_size = (frame_w, frame_h)
         scored = []
         for t in targets:
-            target_wrapper = type("T", (), {
-                "track_id": t.track_id,
-                "area": (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1]),
-                "confidence": t.confidence,
-                "center_distance": self._center_distance(t.center, frame_size),
-                "score": 0.0,
-            })()
-            target_wrapper.score = self.target_selector.calculate_score(target_wrapper, frame_size)
-            scored.append(target_wrapper)
+            area = (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1])
+            center_distance = self._center_distance(t.center, frame_size)
+            target = TrackingTarget(
+                track_id=t.track_id,
+                position=(t.center[0] / frame_w, t.center[1] / frame_h),
+                last_update=time.time(),
+                area=area,
+                confidence=t.confidence,
+                center_distance=center_distance,
+            )
+            target.score = self.target_selector.calculate_score(target, frame_size)
+            scored.append(target)
         scored.sort(key=lambda x: x.score, reverse=True)
         keep_ids = {t.track_id for t in scored[:max_targets]}
         remove_ids = [tid for tid in self.active_targets if tid not in keep_ids]
         for tid in remove_ids:
             self.active_targets.pop(tid, None)
+            self._last_seen_time.pop(tid, None)
             if tid in self.target_order:
                 self.target_order.remove(tid)
-            self._capture_counts.pop(tid, None)
+            with self._capture_counts_lock:
+                self._capture_counts.pop(tid, None)
 
     def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
         cx, cy = frame_size[0] / 2, frame_size[1] / 2
@@ -235,42 +276,50 @@ class PollingTrackingCoordinator:
     def _ptz_worker(self):
         while self.running:
             try:
+                # 暂停时阻塞等待恢复
                 if self._paused:
                     self._paused_event.wait()
                     continue
 
+                # 原子性获取目标快照和当前目标
                 with self.targets_lock:
-                    has_targets = bool(self.active_targets)
                     target_order_snapshot = self.target_order.copy()
+                    has_targets = bool(self.active_targets)
+                    if target_order_snapshot:
+                        if self.current_index >= len(target_order_snapshot):
+                            self.current_index = 0
+                        target_id = target_order_snapshot[self.current_index]
+                        target = self.active_targets.get(target_id)
+                    else:
+                        target_id = None
+                        target = None
 
                 if not has_targets or not target_order_snapshot:
                     self._flush_batch_if_needed()
                     time.sleep(0.1)
                     continue
 
-                if self.current_index >= len(target_order_snapshot):
-                    self.current_index = 0
-
-                target_id = target_order_snapshot[self.current_index]
-
-                with self.targets_lock:
-                    target = self.active_targets.get(target_id)
-
                 if target is None or target.lost:
-                    self._advance(len(target_order_snapshot))
+                    # 目标丢失时跳过但保留在队列中,并短暂休眠避免忙等
+                    time.sleep(0.01)
+                    with self.targets_lock:
+                        self._advance(len(target_order_snapshot))
                     continue
 
                 record = self._capture_one(target)
                 if record:
-                    self.batch_captures.append(record)
+                    with self.batch_lock:
+                        self.batch_captures.append(record)
                     self._update_stats("captures")
 
-                self._advance(len(target_order_snapshot))
+                with self.targets_lock:
+                    self._advance(len(target_order_snapshot))
 
                 # 一轮完成
-                if self.current_index == 0 and self.batch_captures:
-                    self._upload_batch(self.batch_captures)
-                    self.batch_captures.clear()
+                with self.batch_lock:
+                    if self.current_index == 0 and self.batch_captures:
+                        self._upload_batch(self.batch_captures)
+                        self.batch_captures.clear()
 
             except Exception as e:
                 logger.error(f"PTZ 线程错误: {e}")
@@ -299,19 +348,32 @@ class PollingTrackingCoordinator:
         else:
             pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)
 
+        #  enforce PTZ command cooldown
+        ptz_command_cooldown = self.config.get("ptz_command_cooldown", 0.2)
+        with self._last_ptz_command_time_lock:
+            elapsed = time.time() - self._last_ptz_command_time
+            if elapsed < ptz_command_cooldown:
+                time.sleep(ptz_command_cooldown - elapsed)
+
         success = self.ptz.goto_exact_position(pan, tilt, zoom)
         if not success:
             return None
 
-        time.sleep(self.config.get("ptz_stabilize_time", 2.0))
+        with self._last_ptz_command_time_lock:
+            self._last_ptz_command_time = time.time()
+
+        ptz_stabilize_time = self.config.get("ptz_stabilize_time", 2.0)
+        time.sleep(max(ptz_stabilize_time, ptz_command_cooldown))
+
         ptz_frame = self._get_clear_ptz_frame()
         if ptz_frame is None:
             return None
 
         max_cap = self.config.get("max_capture_per_target", 0)
-        if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
-            return None
-        self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
+        with self._capture_counts_lock:
+            if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
+                return None
+            self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
 
         panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
 
@@ -382,9 +444,10 @@ class PollingTrackingCoordinator:
             logger.error(f"批量上传失败: {e}")
 
     def _flush_batch_if_needed(self):
-        if self.batch_captures:
-            self._upload_batch(self.batch_captures)
-            self.batch_captures.clear()
+        with self.batch_lock:
+            if self.batch_captures:
+                self._upload_batch(self.batch_captures)
+                self.batch_captures.clear()
 
     def _update_stats(self, key: str, value: int = 1):
         with self.stats_lock:

+ 205 - 1
dual_camera_system/tests/test_polling_tracker.py

@@ -1,5 +1,7 @@
 import sys
 import os
+import time
+import threading
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 import numpy as np
@@ -8,9 +10,36 @@ from polling_tracker import PollingTrackingCoordinator, CaptureRecord
 from tracker import TrackedPerson
 
 
+class _FakePosition:
+    def __init__(self):
+        self.pan = 0
+        self.tilt = 0
+        self.zoom = 1
+
+
 class FakePanorama:
     def __init__(self):
         self.frame = np.zeros((480, 640, 3), dtype=np.uint8)
+        self.connected = False
+        self.streaming = False
+        self.stopped = False
+
+    def connect(self):
+        self.connected = True
+        return True
+
+    def start_stream_rtsp(self):
+        self.streaming = True
+        return True
+
+    def stop_stream_rtsp(self):
+        self.stopped = True
+        return True
+
+    def disconnect(self):
+        self.connected = False
+        self.streaming = False
+        return True
 
     def get_frame(self):
         return self.frame.copy()
@@ -19,7 +48,18 @@ class FakePanorama:
 class FakePTZ:
     def __init__(self):
         self.commands = []
-        self.current_position = type("P", (), {"pan": 0, "tilt": 0, "zoom": 1})()
+        self.connected = False
+        self.ptz_frame = np.zeros((100, 100, 3), dtype=np.uint8)
+        self.current_position = _FakePosition()
+        self.ptz_config = {"default_zoom": 8}
+
+    def connect(self):
+        self.connected = True
+        return True
+
+    def disconnect(self):
+        self.connected = False
+        return True
 
     def goto_exact_position(self, pan, tilt, zoom):
         self.commands.append((pan, tilt, zoom))
@@ -31,14 +71,35 @@ class FakePTZ:
     def calculate_ptz_position(self, x, y, zoom=None):
         return x * 180, y * 90, zoom or 8
 
+    def get_frame(self):
+        return self.ptz_frame.copy()
+
 
 class FakeTracker:
     def __init__(self, persons):
         self.persons = persons
+        self.released = False
 
     def update(self, frame):
         return self.persons
 
+    def release(self):
+        self.released = True
+
+
+class FakeEventPusher:
+    def __init__(self):
+        self.uploads = []
+        self.pushes = []
+
+    def upload_numpy_image(self, image):
+        url = f"url_{id(image)}"
+        self.uploads.append(url)
+        return url
+
+    def push_tracking_capture(self, batch_time, captures):
+        self.pushes.append({"batch_time": batch_time, "captures": captures})
+
 
 def test_update_active_targets():
     pan = FakePanorama()
@@ -81,3 +142,146 @@ def test_capture_record_creation():
         confidence=0.9,
     )
     assert record.track_id == 1
+
+
+def test_target_lost_and_timeout_removal():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([
+        TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
+    ])
+
+    coord = PollingTrackingCoordinator(
+        pan, ptz, tracker, config={"tracking_timeout": 0.2, "max_tracking_targets": 4}
+    )
+    frame = pan.get_frame()
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+    assert 1 in coord.target_order
+    assert not coord.active_targets[1].lost
+
+    # 当前帧无目标:标记为丢失,但仍在 target_order 中
+    tracker.persons = []
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+    assert 1 in coord.target_order
+    assert coord.active_targets[1].lost
+
+    # 超时后应被移除
+    time.sleep(0.25)
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+    assert 1 not in coord.target_order
+    assert 1 not in coord.active_targets
+
+
+def test_target_lost_not_removed_before_timeout():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([
+        TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
+    ])
+
+    coord = PollingTrackingCoordinator(
+        pan, ptz, tracker, config={"tracking_timeout": 1.0, "max_tracking_targets": 4}
+    )
+    frame = pan.get_frame()
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+
+    tracker.persons = []
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+    assert 1 in coord.target_order
+    assert coord.active_targets[1].lost
+
+
+def test_batch_upload_flush():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([
+        TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
+    ])
+
+    coord = PollingTrackingCoordinator(
+        pan, ptz, tracker,
+        config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0, "enable_upload": True}
+    )
+    pusher = FakeEventPusher()
+    coord.set_event_pusher(pusher)
+
+    frame = pan.get_frame()
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+
+    record = coord._capture_one(list(coord.active_targets.values())[0])
+    assert record is not None
+    coord.batch_captures.append(record)
+
+    coord._flush_batch_if_needed()
+    assert len(coord.batch_captures) == 0
+    assert len(pusher.pushes) == 1
+    assert len(pusher.uploads) == 2  # PTZ + panorama
+
+
+def test_pause_resume():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([])
+
+    coord = PollingTrackingCoordinator(pan, ptz, tracker, config={})
+    coord.pause()
+    assert coord._paused is True
+    assert coord._paused_event.is_set() is False
+
+    coord.resume()
+    assert coord._paused is False
+    assert coord._paused_event.is_set() is True
+
+
+def test_thread_start_stop_lifecycle():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([])
+
+    coord = PollingTrackingCoordinator(
+        pan, ptz, tracker,
+        config={"ptz_stabilize_time": 0.01, "ptz_command_cooldown": 0.0}
+    )
+    assert coord.start() is True
+    assert coord.running is True
+    assert pan.streaming is True
+    assert ptz.connected is True
+
+    time.sleep(0.05)
+    coord.stop()
+    assert coord.running is False
+    assert pan.stopped is True
+    assert tracker.released is True
+
+
+def test_ptz_worker_capture_flow():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([
+        TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(320, 240), confidence=0.9),
+    ])
+
+    coord = PollingTrackingCoordinator(
+        pan, ptz, tracker,
+        config={
+            "ptz_stabilize_time": 0.01,
+            "ptz_command_cooldown": 0.0,
+            "max_capture_per_target": 1,
+        }
+    )
+
+    frame = pan.get_frame()
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+
+    target = coord.active_targets[1]
+    record = coord._capture_one(target)
+
+    assert record is not None
+    assert record.track_id == 1
+    assert len(ptz.commands) == 1
+    with coord._capture_counts_lock:
+        assert coord._capture_counts[1] == 1
+
+    # 超过最大抓拍数后应返回 None
+    record2 = coord._capture_one(target)
+    assert record2 is None