Procházet zdrojové kódy

feat: add PollingTrackingCoordinator for multi-target PTZ capture

wenhongquan před 1 dnem
rodič
revize
dec0921b9a

+ 396 - 0
dual_camera_system/polling_tracker.py

@@ -0,0 +1,396 @@
+"""
+轮询跟踪 + PTZ 抓拍协调器
+"""
+
+import os
+import time
+import threading
+import logging
+from typing import Dict, List, Tuple, Optional
+from dataclasses import dataclass
+
+import cv2
+import numpy as np
+
+from config import TRACKING_CONFIG
+from tracker import UltralyticsTracker, TrackedPerson
+from coordinator import TargetSelector
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CaptureRecord:
+    """单次抓拍记录"""
+    track_id: int
+    timestamp: float
+    position: Tuple[float, float]
+    ptz_position: Tuple[float, float, int]
+    ptz_image: np.ndarray
+    panorama_image: Optional[np.ndarray]
+    confidence: float
+
+
+class PollingTrackingCoordinator:
+    """多目标轮询跟踪 + PTZ 抓拍协调器"""
+
+    def __init__(
+        self,
+        panorama_camera,
+        ptz_camera,
+        tracker: UltralyticsTracker,
+        config: Optional[Dict] = None,
+        calibrator=None,
+    ):
+        self.panorama = panorama_camera
+        self.ptz = ptz_camera
+        self.tracker = tracker
+        self.config = config or TRACKING_CONFIG
+        self.calibrator = calibrator
+
+        self.active_targets: Dict[int, TrackedPerson] = {}
+        self.target_order: List[int] = []
+        self.current_index: int = 0
+        self.batch_captures: List[CaptureRecord] = []
+        self._capture_counts: Dict[int, int] = {}
+
+        self.targets_lock = threading.Lock()
+        self.running = False
+        self._detection_thread = None
+        self._ptz_thread = None
+        self._paused = False
+        self._paused_event = threading.Event()
+        self._paused_event.set()
+
+        self.target_selector = TargetSelector(self.config.get("target_selection", {}))
+        self.event_pusher = None
+
+        self.stats = {
+            "frames_processed": 0,
+            "persons_detected": 0,
+            "captures": 0,
+            "uploads": 0,
+            "start_time": None,
+        }
+        self.stats_lock = threading.Lock()
+
+        self._ensure_capture_dir()
+
+    def set_event_pusher(self, event_pusher):
+        self.event_pusher = event_pusher
+
+    def _ensure_capture_dir(self):
+        capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
+        try:
+            os.makedirs(capture_dir, exist_ok=True)
+        except OSError as e:
+            logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
+
+    def start(self) -> bool:
+        if not self.panorama.connect():
+            print("连接全景摄像头失败")
+            return False
+        if not self.ptz.connect():
+            print("连接球机失败")
+            self.panorama.disconnect()
+            return False
+        if not self.panorama.start_stream_rtsp():
+            print("启动全景视频流失败")
+            self.panorama.disconnect()
+            self.ptz.disconnect()
+            return False
+
+        self.running = True
+        self._detection_thread = threading.Thread(target=self._detection_worker, daemon=True)
+        self._detection_thread.start()
+        self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
+        self._ptz_thread.start()
+
+        with self.stats_lock:
+            self.stats["start_time"] = time.time()
+
+        print("轮询跟踪抓拍协调器已启动")
+        return True
+
+    def stop(self):
+        self.running = False
+        if self._detection_thread:
+            self._detection_thread.join(timeout=3)
+        if self._ptz_thread:
+            self._ptz_thread.join(timeout=3)
+        self.panorama.disconnect()
+        self.ptz.disconnect()
+        print("轮询跟踪抓拍协调器已停止")
+
+    def pause(self):
+        self._paused = True
+        self._paused_event.clear()
+
+    def resume(self):
+        self._paused = False
+        self._paused_event.set()
+
+    def _detection_worker(self):
+        self._paused_event.wait()
+        detection_fps = self.config.get("detection_fps", 2)
+        detection_interval = 1.0 / detection_fps
+        last_detection_time = 0
+
+        while self.running:
+            try:
+                frame = self.panorama.get_frame()
+                if frame is None:
+                    time.sleep(0.01)
+                    continue
+
+                self._update_stats("frames_processed")
+
+                current_time = time.time()
+                if not self._paused and current_time - last_detection_time >= detection_interval:
+                    last_detection_time = current_time
+                    tracked = self.tracker.update(frame)
+                    self._update_active_targets(tracked, frame.shape)
+                    if tracked:
+                        self._update_stats("persons_detected", len(tracked))
+
+                time.sleep(0.01)
+            except Exception as e:
+                logger.error(f"检测线程错误: {e}")
+                time.sleep(0.1)
+
+    def _update_active_targets(self, tracked: List[TrackedPerson], frame_shape):
+        current_time = time.time()
+        frame_h, frame_w = frame_shape[:2]
+        timeout = self.config.get("tracking_timeout", 3.0)
+        max_targets = self.config.get("max_tracking_targets", 4)
+
+        with self.targets_lock:
+            # 更新或新增
+            updated_ids = set()
+            for p in tracked:
+                if p.track_id < 0:
+                    continue
+                updated_ids.add(p.track_id)
+                self.active_targets[p.track_id] = p
+                if p.track_id not in self.target_order:
+                    self.target_order.append(p.track_id)
+
+            # 标记丢失
+            lost_ids = [tid for tid in self.target_order if tid not in updated_ids]
+            for tid in lost_ids:
+                t = self.active_targets.get(tid)
+                if t is not None:
+                    t.lost = True
+
+            # 移除长期丢失
+            remove_ids = []
+            for tid in self.target_order:
+                t = self.active_targets.get(tid)
+                if t is None:
+                    remove_ids.append(tid)
+                    continue
+                # 简单丢失超时移除
+                if t.lost:
+                    remove_ids.append(tid)
+
+            for tid in remove_ids:
+                if tid in self.target_order:
+                    self.target_order.remove(tid)
+                self.active_targets.pop(tid, None)
+                self._capture_counts.pop(tid, None)
+
+            # 人数上限淘汰
+            if len(self.active_targets) > max_targets:
+                self._prune_targets(frame_w, frame_h, max_targets)
+
+    def _prune_targets(self, frame_w: int, frame_h: int, max_targets: int):
+        targets = list(self.active_targets.values())
+        frame_size = (frame_w, frame_h)
+        scored = []
+        for t in targets:
+            target_wrapper = type("T", (), {
+                "track_id": t.track_id,
+                "area": (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1]),
+                "confidence": t.confidence,
+                "center_distance": self._center_distance(t.center, frame_size),
+                "score": 0.0,
+            })()
+            target_wrapper.score = self.target_selector.calculate_score(target_wrapper, frame_size)
+            scored.append(target_wrapper)
+        scored.sort(key=lambda x: x.score, reverse=True)
+        keep_ids = {t.track_id for t in scored[:max_targets]}
+        remove_ids = [tid for tid in self.active_targets if tid not in keep_ids]
+        for tid in remove_ids:
+            self.active_targets.pop(tid, None)
+            if tid in self.target_order:
+                self.target_order.remove(tid)
+            self._capture_counts.pop(tid, None)
+
+    def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
+        cx, cy = frame_size[0] / 2, frame_size[1] / 2
+        dx = abs(center[0] - cx) / cx
+        dy = abs(center[1] - cy) / cy
+        return (dx + dy) / 2
+
+    def _ptz_worker(self):
+        while self.running:
+            try:
+                if self._paused:
+                    self._paused_event.wait()
+                    continue
+
+                with self.targets_lock:
+                    has_targets = bool(self.active_targets)
+                    target_order_snapshot = self.target_order.copy()
+
+                if not has_targets or not target_order_snapshot:
+                    self._flush_batch_if_needed()
+                    time.sleep(0.1)
+                    continue
+
+                if self.current_index >= len(target_order_snapshot):
+                    self.current_index = 0
+
+                target_id = target_order_snapshot[self.current_index]
+
+                with self.targets_lock:
+                    target = self.active_targets.get(target_id)
+
+                if target is None or target.lost:
+                    self._advance(len(target_order_snapshot))
+                    continue
+
+                record = self._capture_one(target)
+                if record:
+                    self.batch_captures.append(record)
+                    self._update_stats("captures")
+
+                self._advance(len(target_order_snapshot))
+
+                # 一轮完成
+                if self.current_index == 0 and self.batch_captures:
+                    self._upload_batch(self.batch_captures)
+                    self.batch_captures.clear()
+
+            except Exception as e:
+                logger.error(f"PTZ 线程错误: {e}")
+                time.sleep(0.1)
+
+    def _advance(self, order_len: int = None):
+        if order_len is None:
+            order_len = len(self.target_order) or 1
+        self.current_index = (self.current_index + 1) % order_len
+
+    def _capture_one(self, target: TrackedPerson) -> Optional[CaptureRecord]:
+        frame = self.panorama.get_frame()
+        if frame is None:
+            return None
+        frame_h, frame_w = frame.shape[:2]
+
+        x_ratio = target.center[0] / frame_w
+        y_ratio = target.center[1] / frame_h
+
+        if self.calibrator and self.calibrator.is_calibrated():
+            pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
+            ptz_config = getattr(self.ptz, "ptz_config", {})
+            if ptz_config.get("pan_flip"):
+                pan = (pan + 180) % 360
+            zoom = ptz_config.get("default_zoom", 8)
+        else:
+            pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)
+
+        success = self.ptz.goto_exact_position(pan, tilt, zoom)
+        if not success:
+            return None
+
+        time.sleep(self.config.get("ptz_stabilize_time", 2.0))
+        ptz_frame = self._get_clear_ptz_frame()
+        if ptz_frame is None:
+            return None
+
+        max_cap = self.config.get("max_capture_per_target", 0)
+        if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
+            return None
+        self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
+
+        panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
+
+        # 本地保存
+        self._save_local(ptz_frame, panorama_image, target, pan, tilt, zoom)
+
+        return CaptureRecord(
+            track_id=target.track_id,
+            timestamp=time.time(),
+            position=(x_ratio, y_ratio),
+            ptz_position=(pan, tilt, zoom),
+            ptz_image=ptz_frame,
+            panorama_image=panorama_image,
+            confidence=target.confidence,
+        )
+
+    def _get_clear_ptz_frame(self, max_attempts: int = 5, wait_interval: float = 0.2) -> Optional[np.ndarray]:
+        best_frame = None
+        best_score = -1
+        for _ in range(max_attempts):
+            frame = self.ptz.get_frame()
+            if frame is not None:
+                frame_copy = frame.copy()
+                gray = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2GRAY)
+                score = cv2.Laplacian(gray, cv2.CV_64F).var()
+                if score > best_score:
+                    best_score = score
+                    best_frame = frame_copy
+            time.sleep(wait_interval)
+        return best_frame
+
+    def _save_local(self, ptz_frame, panorama_image, target, pan, tilt, zoom):
+        capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
+        try:
+            os.makedirs(capture_dir, exist_ok=True)
+        except OSError as e:
+            logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
+            return
+        timestamp = int(time.time() * 1000)
+        base = f"{capture_dir}/ptz_{target.track_id}_{timestamp}_{pan:.0f}_{tilt:.0f}_z{zoom}.jpg"
+        cv2.imwrite(base, ptz_frame)
+        if panorama_image is not None:
+            pan_base = f"{capture_dir}/panorama_{target.track_id}_{timestamp}.jpg"
+            cv2.imwrite(pan_base, panorama_image)
+
+    def _upload_batch(self, records: List[CaptureRecord]):
+        if not self.event_pusher or not self.config.get("enable_upload", True):
+            return
+        try:
+            uploads = []
+            for r in records:
+                ptz_url = self.event_pusher.upload_numpy_image(r.ptz_image)
+                pan_url = None
+                if r.panorama_image is not None:
+                    pan_url = self.event_pusher.upload_numpy_image(r.panorama_image)
+                uploads.append({
+                    "track_id": r.track_id,
+                    "ptz_image_url": ptz_url,
+                    "panorama_image_url": pan_url,
+                    "position": r.position,
+                    "ptz_position": r.ptz_position,
+                    "confidence": r.confidence,
+                    "timestamp": r.timestamp,
+                })
+            self.event_pusher.push_tracking_capture(batch_time=time.time(), captures=uploads)
+            self._update_stats("uploads")
+        except Exception as e:
+            logger.error(f"批量上传失败: {e}")
+
+    def _flush_batch_if_needed(self):
+        if self.batch_captures:
+            self._upload_batch(self.batch_captures)
+            self.batch_captures.clear()
+
+    def _update_stats(self, key: str, value: int = 1):
+        with self.stats_lock:
+            if key in self.stats:
+                self.stats[key] += value
+
+    def get_stats(self) -> Dict:
+        with self.stats_lock:
+            return self.stats.copy()

+ 83 - 0
dual_camera_system/tests/test_polling_tracker.py

@@ -0,0 +1,83 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import pytest
+from polling_tracker import PollingTrackingCoordinator, CaptureRecord
+from tracker import TrackedPerson
+
+
+class FakePanorama:
+    def __init__(self):
+        self.frame = np.zeros((480, 640, 3), dtype=np.uint8)
+
+    def get_frame(self):
+        return self.frame.copy()
+
+
+class FakePTZ:
+    def __init__(self):
+        self.commands = []
+        self.current_position = type("P", (), {"pan": 0, "tilt": 0, "zoom": 1})()
+
+    def goto_exact_position(self, pan, tilt, zoom):
+        self.commands.append((pan, tilt, zoom))
+        return True
+
+    def get_current_position(self):
+        return self.current_position
+
+    def calculate_ptz_position(self, x, y, zoom=None):
+        return x * 180, y * 90, zoom or 8
+
+
+class FakeTracker:
+    def __init__(self, persons):
+        self.persons = persons
+
+    def update(self, frame):
+        return self.persons
+
+
+def test_update_active_targets():
+    pan = FakePanorama()
+    ptz = FakePTZ()
+    tracker = FakeTracker([
+        TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9),
+        TrackedPerson(track_id=2, bbox=(50, 60, 70, 80), center=(60, 70), confidence=0.8),
+    ])
+
+    coord = PollingTrackingCoordinator(pan, ptz, tracker, config={"max_tracking_targets": 4})
+    frame = pan.get_frame()
+    coord._update_active_targets(tracker.update(frame), frame.shape)
+
+    assert len(coord.active_targets) == 2
+    assert 1 in coord.target_order
+    assert 2 in coord.target_order
+
+
+def test_advance_loop():
+    coord = PollingTrackingCoordinator.__new__(PollingTrackingCoordinator)
+    coord.target_order = [1, 2, 3]
+    coord.current_index = 0
+
+    coord._advance()
+    assert coord.current_index == 1
+
+    coord.current_index = 2
+    coord._advance()
+    assert coord.current_index == 0
+
+
+def test_capture_record_creation():
+    record = CaptureRecord(
+        track_id=1,
+        timestamp=1.0,
+        position=(0.5, 0.5),
+        ptz_position=(90.0, 45.0, 8),
+        ptz_image=np.zeros((100, 100, 3), dtype=np.uint8),
+        panorama_image=None,
+        confidence=0.9,
+    )
+    assert record.track_id == 1