|
@@ -0,0 +1,396 @@
|
|
|
|
|
+"""
|
|
|
|
|
+轮询跟踪 + PTZ 抓拍协调器
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import os
|
|
|
|
|
+import time
|
|
|
|
|
+import threading
|
|
|
|
|
+import logging
|
|
|
|
|
+from typing import Dict, List, Tuple, Optional
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+from config import TRACKING_CONFIG
|
|
|
|
|
+from tracker import UltralyticsTracker, TrackedPerson
|
|
|
|
|
+from coordinator import TargetSelector
|
|
|
|
|
+
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@dataclass
|
|
|
|
|
+class CaptureRecord:
|
|
|
|
|
+ """单次抓拍记录"""
|
|
|
|
|
+ track_id: int
|
|
|
|
|
+ timestamp: float
|
|
|
|
|
+ position: Tuple[float, float]
|
|
|
|
|
+ ptz_position: Tuple[float, float, int]
|
|
|
|
|
+ ptz_image: np.ndarray
|
|
|
|
|
+ panorama_image: Optional[np.ndarray]
|
|
|
|
|
+ confidence: float
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class PollingTrackingCoordinator:
|
|
|
|
|
+ """多目标轮询跟踪 + PTZ 抓拍协调器"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ panorama_camera,
|
|
|
|
|
+ ptz_camera,
|
|
|
|
|
+ tracker: UltralyticsTracker,
|
|
|
|
|
+ config: Optional[Dict] = None,
|
|
|
|
|
+ calibrator=None,
|
|
|
|
|
+ ):
|
|
|
|
|
+ self.panorama = panorama_camera
|
|
|
|
|
+ self.ptz = ptz_camera
|
|
|
|
|
+ self.tracker = tracker
|
|
|
|
|
+ self.config = config or TRACKING_CONFIG
|
|
|
|
|
+ self.calibrator = calibrator
|
|
|
|
|
+
|
|
|
|
|
+ self.active_targets: Dict[int, TrackedPerson] = {}
|
|
|
|
|
+ self.target_order: List[int] = []
|
|
|
|
|
+ self.current_index: int = 0
|
|
|
|
|
+ self.batch_captures: List[CaptureRecord] = []
|
|
|
|
|
+ self._capture_counts: Dict[int, int] = {}
|
|
|
|
|
+
|
|
|
|
|
+ self.targets_lock = threading.Lock()
|
|
|
|
|
+ self.running = False
|
|
|
|
|
+ self._detection_thread = None
|
|
|
|
|
+ self._ptz_thread = None
|
|
|
|
|
+ self._paused = False
|
|
|
|
|
+ self._paused_event = threading.Event()
|
|
|
|
|
+ self._paused_event.set()
|
|
|
|
|
+
|
|
|
|
|
+ self.target_selector = TargetSelector(self.config.get("target_selection", {}))
|
|
|
|
|
+ self.event_pusher = None
|
|
|
|
|
+
|
|
|
|
|
+ self.stats = {
|
|
|
|
|
+ "frames_processed": 0,
|
|
|
|
|
+ "persons_detected": 0,
|
|
|
|
|
+ "captures": 0,
|
|
|
|
|
+ "uploads": 0,
|
|
|
|
|
+ "start_time": None,
|
|
|
|
|
+ }
|
|
|
|
|
+ self.stats_lock = threading.Lock()
|
|
|
|
|
+
|
|
|
|
|
+ self._ensure_capture_dir()
|
|
|
|
|
+
|
|
|
|
|
+ def set_event_pusher(self, event_pusher):
|
|
|
|
|
+ self.event_pusher = event_pusher
|
|
|
|
|
+
|
|
|
|
|
+ def _ensure_capture_dir(self):
|
|
|
|
|
+ capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
|
|
|
|
|
+ try:
|
|
|
|
|
+ os.makedirs(capture_dir, exist_ok=True)
|
|
|
|
|
+ except OSError as e:
|
|
|
|
|
+ logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ def start(self) -> bool:
|
|
|
|
|
+ if not self.panorama.connect():
|
|
|
|
|
+ print("连接全景摄像头失败")
|
|
|
|
|
+ return False
|
|
|
|
|
+ if not self.ptz.connect():
|
|
|
|
|
+ print("连接球机失败")
|
|
|
|
|
+ self.panorama.disconnect()
|
|
|
|
|
+ return False
|
|
|
|
|
+ if not self.panorama.start_stream_rtsp():
|
|
|
|
|
+ print("启动全景视频流失败")
|
|
|
|
|
+ self.panorama.disconnect()
|
|
|
|
|
+ self.ptz.disconnect()
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ self.running = True
|
|
|
|
|
+ self._detection_thread = threading.Thread(target=self._detection_worker, daemon=True)
|
|
|
|
|
+ self._detection_thread.start()
|
|
|
|
|
+ self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
|
|
|
|
|
+ self._ptz_thread.start()
|
|
|
|
|
+
|
|
|
|
|
+ with self.stats_lock:
|
|
|
|
|
+ self.stats["start_time"] = time.time()
|
|
|
|
|
+
|
|
|
|
|
+ print("轮询跟踪抓拍协调器已启动")
|
|
|
|
|
+ return True
|
|
|
|
|
+
|
|
|
|
|
+ def stop(self):
|
|
|
|
|
+ self.running = False
|
|
|
|
|
+ if self._detection_thread:
|
|
|
|
|
+ self._detection_thread.join(timeout=3)
|
|
|
|
|
+ if self._ptz_thread:
|
|
|
|
|
+ self._ptz_thread.join(timeout=3)
|
|
|
|
|
+ self.panorama.disconnect()
|
|
|
|
|
+ self.ptz.disconnect()
|
|
|
|
|
+ print("轮询跟踪抓拍协调器已停止")
|
|
|
|
|
+
|
|
|
|
|
+ def pause(self):
|
|
|
|
|
+ self._paused = True
|
|
|
|
|
+ self._paused_event.clear()
|
|
|
|
|
+
|
|
|
|
|
+ def resume(self):
|
|
|
|
|
+ self._paused = False
|
|
|
|
|
+ self._paused_event.set()
|
|
|
|
|
+
|
|
|
|
|
+ def _detection_worker(self):
|
|
|
|
|
+ self._paused_event.wait()
|
|
|
|
|
+ detection_fps = self.config.get("detection_fps", 2)
|
|
|
|
|
+ detection_interval = 1.0 / detection_fps
|
|
|
|
|
+ last_detection_time = 0
|
|
|
|
|
+
|
|
|
|
|
+ while self.running:
|
|
|
|
|
+ try:
|
|
|
|
|
+ frame = self.panorama.get_frame()
|
|
|
|
|
+ if frame is None:
|
|
|
|
|
+ time.sleep(0.01)
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ self._update_stats("frames_processed")
|
|
|
|
|
+
|
|
|
|
|
+ current_time = time.time()
|
|
|
|
|
+ if not self._paused and current_time - last_detection_time >= detection_interval:
|
|
|
|
|
+ last_detection_time = current_time
|
|
|
|
|
+ tracked = self.tracker.update(frame)
|
|
|
|
|
+ self._update_active_targets(tracked, frame.shape)
|
|
|
|
|
+ if tracked:
|
|
|
|
|
+ self._update_stats("persons_detected", len(tracked))
|
|
|
|
|
+
|
|
|
|
|
+ time.sleep(0.01)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"检测线程错误: {e}")
|
|
|
|
|
+ time.sleep(0.1)
|
|
|
|
|
+
|
|
|
|
|
+ def _update_active_targets(self, tracked: List[TrackedPerson], frame_shape):
|
|
|
|
|
+ current_time = time.time()
|
|
|
|
|
+ frame_h, frame_w = frame_shape[:2]
|
|
|
|
|
+ timeout = self.config.get("tracking_timeout", 3.0)
|
|
|
|
|
+ max_targets = self.config.get("max_tracking_targets", 4)
|
|
|
|
|
+
|
|
|
|
|
+ with self.targets_lock:
|
|
|
|
|
+ # 更新或新增
|
|
|
|
|
+ updated_ids = set()
|
|
|
|
|
+ for p in tracked:
|
|
|
|
|
+ if p.track_id < 0:
|
|
|
|
|
+ continue
|
|
|
|
|
+ updated_ids.add(p.track_id)
|
|
|
|
|
+ self.active_targets[p.track_id] = p
|
|
|
|
|
+ if p.track_id not in self.target_order:
|
|
|
|
|
+ self.target_order.append(p.track_id)
|
|
|
|
|
+
|
|
|
|
|
+ # 标记丢失
|
|
|
|
|
+ lost_ids = [tid for tid in self.target_order if tid not in updated_ids]
|
|
|
|
|
+ for tid in lost_ids:
|
|
|
|
|
+ t = self.active_targets.get(tid)
|
|
|
|
|
+ if t is not None:
|
|
|
|
|
+ t.lost = True
|
|
|
|
|
+
|
|
|
|
|
+ # 移除长期丢失
|
|
|
|
|
+ remove_ids = []
|
|
|
|
|
+ for tid in self.target_order:
|
|
|
|
|
+ t = self.active_targets.get(tid)
|
|
|
|
|
+ if t is None:
|
|
|
|
|
+ remove_ids.append(tid)
|
|
|
|
|
+ continue
|
|
|
|
|
+ # 简单丢失超时移除
|
|
|
|
|
+ if t.lost:
|
|
|
|
|
+ remove_ids.append(tid)
|
|
|
|
|
+
|
|
|
|
|
+ for tid in remove_ids:
|
|
|
|
|
+ if tid in self.target_order:
|
|
|
|
|
+ self.target_order.remove(tid)
|
|
|
|
|
+ self.active_targets.pop(tid, None)
|
|
|
|
|
+ self._capture_counts.pop(tid, None)
|
|
|
|
|
+
|
|
|
|
|
+ # 人数上限淘汰
|
|
|
|
|
+ if len(self.active_targets) > max_targets:
|
|
|
|
|
+ self._prune_targets(frame_w, frame_h, max_targets)
|
|
|
|
|
+
|
|
|
|
|
+ def _prune_targets(self, frame_w: int, frame_h: int, max_targets: int):
|
|
|
|
|
+ targets = list(self.active_targets.values())
|
|
|
|
|
+ frame_size = (frame_w, frame_h)
|
|
|
|
|
+ scored = []
|
|
|
|
|
+ for t in targets:
|
|
|
|
|
+ target_wrapper = type("T", (), {
|
|
|
|
|
+ "track_id": t.track_id,
|
|
|
|
|
+ "area": (t.bbox[2] - t.bbox[0]) * (t.bbox[3] - t.bbox[1]),
|
|
|
|
|
+ "confidence": t.confidence,
|
|
|
|
|
+ "center_distance": self._center_distance(t.center, frame_size),
|
|
|
|
|
+ "score": 0.0,
|
|
|
|
|
+ })()
|
|
|
|
|
+ target_wrapper.score = self.target_selector.calculate_score(target_wrapper, frame_size)
|
|
|
|
|
+ scored.append(target_wrapper)
|
|
|
|
|
+ scored.sort(key=lambda x: x.score, reverse=True)
|
|
|
|
|
+ keep_ids = {t.track_id for t in scored[:max_targets]}
|
|
|
|
|
+ remove_ids = [tid for tid in self.active_targets if tid not in keep_ids]
|
|
|
|
|
+ for tid in remove_ids:
|
|
|
|
|
+ self.active_targets.pop(tid, None)
|
|
|
|
|
+ if tid in self.target_order:
|
|
|
|
|
+ self.target_order.remove(tid)
|
|
|
|
|
+ self._capture_counts.pop(tid, None)
|
|
|
|
|
+
|
|
|
|
|
+ def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
|
|
|
|
|
+ cx, cy = frame_size[0] / 2, frame_size[1] / 2
|
|
|
|
|
+ dx = abs(center[0] - cx) / cx
|
|
|
|
|
+ dy = abs(center[1] - cy) / cy
|
|
|
|
|
+ return (dx + dy) / 2
|
|
|
|
|
+
|
|
|
|
|
+ def _ptz_worker(self):
|
|
|
|
|
+ while self.running:
|
|
|
|
|
+ try:
|
|
|
|
|
+ if self._paused:
|
|
|
|
|
+ self._paused_event.wait()
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ with self.targets_lock:
|
|
|
|
|
+ has_targets = bool(self.active_targets)
|
|
|
|
|
+ target_order_snapshot = self.target_order.copy()
|
|
|
|
|
+
|
|
|
|
|
+ if not has_targets or not target_order_snapshot:
|
|
|
|
|
+ self._flush_batch_if_needed()
|
|
|
|
|
+ time.sleep(0.1)
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ if self.current_index >= len(target_order_snapshot):
|
|
|
|
|
+ self.current_index = 0
|
|
|
|
|
+
|
|
|
|
|
+ target_id = target_order_snapshot[self.current_index]
|
|
|
|
|
+
|
|
|
|
|
+ with self.targets_lock:
|
|
|
|
|
+ target = self.active_targets.get(target_id)
|
|
|
|
|
+
|
|
|
|
|
+ if target is None or target.lost:
|
|
|
|
|
+ self._advance(len(target_order_snapshot))
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ record = self._capture_one(target)
|
|
|
|
|
+ if record:
|
|
|
|
|
+ self.batch_captures.append(record)
|
|
|
|
|
+ self._update_stats("captures")
|
|
|
|
|
+
|
|
|
|
|
+ self._advance(len(target_order_snapshot))
|
|
|
|
|
+
|
|
|
|
|
+ # 一轮完成
|
|
|
|
|
+ if self.current_index == 0 and self.batch_captures:
|
|
|
|
|
+ self._upload_batch(self.batch_captures)
|
|
|
|
|
+ self.batch_captures.clear()
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"PTZ 线程错误: {e}")
|
|
|
|
|
+ time.sleep(0.1)
|
|
|
|
|
+
|
|
|
|
|
+ def _advance(self, order_len: int = None):
|
|
|
|
|
+ if order_len is None:
|
|
|
|
|
+ order_len = len(self.target_order) or 1
|
|
|
|
|
+ self.current_index = (self.current_index + 1) % order_len
|
|
|
|
|
+
|
|
|
|
|
+ def _capture_one(self, target: TrackedPerson) -> Optional[CaptureRecord]:
|
|
|
|
|
+ frame = self.panorama.get_frame()
|
|
|
|
|
+ if frame is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+ frame_h, frame_w = frame.shape[:2]
|
|
|
|
|
+
|
|
|
|
|
+ x_ratio = target.center[0] / frame_w
|
|
|
|
|
+ y_ratio = target.center[1] / frame_h
|
|
|
|
|
+
|
|
|
|
|
+ if self.calibrator and self.calibrator.is_calibrated():
|
|
|
|
|
+ pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
|
|
|
|
|
+ ptz_config = getattr(self.ptz, "ptz_config", {})
|
|
|
|
|
+ if ptz_config.get("pan_flip"):
|
|
|
|
|
+ pan = (pan + 180) % 360
|
|
|
|
|
+ zoom = ptz_config.get("default_zoom", 8)
|
|
|
|
|
+ else:
|
|
|
|
|
+ pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)
|
|
|
|
|
+
|
|
|
|
|
+ success = self.ptz.goto_exact_position(pan, tilt, zoom)
|
|
|
|
|
+ if not success:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ time.sleep(self.config.get("ptz_stabilize_time", 2.0))
|
|
|
|
|
+ ptz_frame = self._get_clear_ptz_frame()
|
|
|
|
|
+ if ptz_frame is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ max_cap = self.config.get("max_capture_per_target", 0)
|
|
|
|
|
+ if max_cap > 0 and self._capture_counts.get(target.track_id, 0) >= max_cap:
|
|
|
|
|
+ return None
|
|
|
|
|
+ self._capture_counts[target.track_id] = self._capture_counts.get(target.track_id, 0) + 1
|
|
|
|
|
+
|
|
|
|
|
+ panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
|
|
|
|
|
+
|
|
|
|
|
+ # 本地保存
|
|
|
|
|
+ self._save_local(ptz_frame, panorama_image, target, pan, tilt, zoom)
|
|
|
|
|
+
|
|
|
|
|
+ return CaptureRecord(
|
|
|
|
|
+ track_id=target.track_id,
|
|
|
|
|
+ timestamp=time.time(),
|
|
|
|
|
+ position=(x_ratio, y_ratio),
|
|
|
|
|
+ ptz_position=(pan, tilt, zoom),
|
|
|
|
|
+ ptz_image=ptz_frame,
|
|
|
|
|
+ panorama_image=panorama_image,
|
|
|
|
|
+ confidence=target.confidence,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _get_clear_ptz_frame(self, max_attempts: int = 5, wait_interval: float = 0.2) -> Optional[np.ndarray]:
|
|
|
|
|
+ best_frame = None
|
|
|
|
|
+ best_score = -1
|
|
|
|
|
+ for _ in range(max_attempts):
|
|
|
|
|
+ frame = self.ptz.get_frame()
|
|
|
|
|
+ if frame is not None:
|
|
|
|
|
+ frame_copy = frame.copy()
|
|
|
|
|
+ gray = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2GRAY)
|
|
|
|
|
+ score = cv2.Laplacian(gray, cv2.CV_64F).var()
|
|
|
|
|
+ if score > best_score:
|
|
|
|
|
+ best_score = score
|
|
|
|
|
+ best_frame = frame_copy
|
|
|
|
|
+ time.sleep(wait_interval)
|
|
|
|
|
+ return best_frame
|
|
|
|
|
+
|
|
|
|
|
+ def _save_local(self, ptz_frame, panorama_image, target, pan, tilt, zoom):
|
|
|
|
|
+ capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
|
|
|
|
|
+ try:
|
|
|
|
|
+ os.makedirs(capture_dir, exist_ok=True)
|
|
|
|
|
+ except OSError as e:
|
|
|
|
|
+ logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
|
|
|
|
|
+ return
|
|
|
|
|
+ timestamp = int(time.time() * 1000)
|
|
|
|
|
+ base = f"{capture_dir}/ptz_{target.track_id}_{timestamp}_{pan:.0f}_{tilt:.0f}_z{zoom}.jpg"
|
|
|
|
|
+ cv2.imwrite(base, ptz_frame)
|
|
|
|
|
+ if panorama_image is not None:
|
|
|
|
|
+ pan_base = f"{capture_dir}/panorama_{target.track_id}_{timestamp}.jpg"
|
|
|
|
|
+ cv2.imwrite(pan_base, panorama_image)
|
|
|
|
|
+
|
|
|
|
|
+ def _upload_batch(self, records: List[CaptureRecord]):
|
|
|
|
|
+ if not self.event_pusher or not self.config.get("enable_upload", True):
|
|
|
|
|
+ return
|
|
|
|
|
+ try:
|
|
|
|
|
+ uploads = []
|
|
|
|
|
+ for r in records:
|
|
|
|
|
+ ptz_url = self.event_pusher.upload_numpy_image(r.ptz_image)
|
|
|
|
|
+ pan_url = None
|
|
|
|
|
+ if r.panorama_image is not None:
|
|
|
|
|
+ pan_url = self.event_pusher.upload_numpy_image(r.panorama_image)
|
|
|
|
|
+ uploads.append({
|
|
|
|
|
+ "track_id": r.track_id,
|
|
|
|
|
+ "ptz_image_url": ptz_url,
|
|
|
|
|
+ "panorama_image_url": pan_url,
|
|
|
|
|
+ "position": r.position,
|
|
|
|
|
+ "ptz_position": r.ptz_position,
|
|
|
|
|
+ "confidence": r.confidence,
|
|
|
|
|
+ "timestamp": r.timestamp,
|
|
|
|
|
+ })
|
|
|
|
|
+ self.event_pusher.push_tracking_capture(batch_time=time.time(), captures=uploads)
|
|
|
|
|
+ self._update_stats("uploads")
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"批量上传失败: {e}")
|
|
|
|
|
+
|
|
|
|
|
+ def _flush_batch_if_needed(self):
|
|
|
|
|
+ if self.batch_captures:
|
|
|
|
|
+ self._upload_batch(self.batch_captures)
|
|
|
|
|
+ self.batch_captures.clear()
|
|
|
|
|
+
|
|
|
|
|
+ def _update_stats(self, key: str, value: int = 1):
|
|
|
|
|
+ with self.stats_lock:
|
|
|
|
|
+ if key in self.stats:
|
|
|
|
|
+ self.stats[key] += value
|
|
|
|
|
+
|
|
|
|
|
+ def get_stats(self) -> Dict:
|
|
|
|
|
+ with self.stats_lock:
|
|
|
|
|
+ return self.stats.copy()
|