| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472 |
- """
- 轮询跟踪 + PTZ 抓拍协调器
- """
- import os
- import time
- import threading
- import logging
- from typing import Dict, List, Tuple, Optional
- from dataclasses import dataclass
- import cv2
- import numpy as np
- from config import TRACKING_CONFIG
- from tracker import UltralyticsTracker, TrackedPerson
- from coordinator import TargetSelector, TrackingTarget
# Module-scoped logger, addressable through this module's dotted name so its
# level and handlers can be configured independently.
logger = logging.getLogger(name=__name__)
@dataclass(eq=False)
class CaptureRecord:
    """A single PTZ capture of one tracked person.

    ``eq=False`` is deliberate: the field-wise ``__eq__`` that ``@dataclass``
    would generate compares the numpy image fields with ``==``. That produces
    an element-wise array whose truth value is ambiguous, so comparing two
    records (or testing ``record in some_list``) would raise ``ValueError``.
    Records therefore compare and hash by identity.
    """

    track_id: int  # tracker-assigned id of the captured person
    timestamp: float  # creation time of the record (seconds since the epoch)
    position: Tuple[float, float]  # target centre as fractions of panorama width/height
    ptz_position: Tuple[float, float, int]  # (pan, tilt, zoom) sent to the PTZ camera
    ptz_image: np.ndarray  # frame grabbed from the PTZ camera after it settled
    panorama_image: Optional[np.ndarray]  # matching panorama frame, or None if not kept
    confidence: float  # detection confidence of the tracked person
class PollingTrackingCoordinator:
    """Round-robin multi-target tracking coordinator with PTZ snapshot capture.

    Two daemon threads cooperate:

    * the detection worker pulls panorama frames, runs ``tracker.update`` at
      ``detection_fps`` and maintains ``active_targets`` / ``target_order``;
    * the PTZ worker visits ``target_order`` round robin, steers the PTZ
      camera to each visible target, keeps the sharpest of a few PTZ frames
      and queues a :class:`CaptureRecord`. Whenever the visit index wraps back
      to the first target, the queued records are uploaded as one batch.

    Locking: ``targets_lock`` guards ``active_targets``, ``target_order``,
    ``current_index`` and ``_last_seen_time``; ``batch_lock`` guards
    ``batch_captures``; ``_capture_counts_lock`` guards ``_capture_counts``;
    ``stats_lock`` guards ``stats``.
    """

    def __init__(
        self,
        panorama_camera,
        ptz_camera,
        tracker: UltralyticsTracker,
        config: Optional[Dict] = None,
        calibrator=None,
    ):
        """Create the coordinator. Cameras are not connected until ``start()``.

        Args:
            panorama_camera: wide-angle camera providing ``connect``,
                ``disconnect``, ``start_stream_rtsp`` and ``get_frame``
                (``stop_stream_rtsp`` is optional).
            ptz_camera: PTZ camera providing ``connect``, ``disconnect``,
                ``get_frame``, ``goto_exact_position`` and, when no usable
                calibrator is given, ``calculate_ptz_position``.
            tracker: tracker whose ``update(frame)`` returns the tracked
                persons of a frame (``release`` is optional).
            config: tracking settings; ``None`` or an empty dict falls back to
                ``TRACKING_CONFIG``.
            calibrator: optional panorama-to-PTZ mapping exposing
                ``is_calibrated()`` and ``transform(x_ratio, y_ratio)``.
        """
        self.panorama = panorama_camera
        self.ptz = ptz_camera
        self.tracker = tracker
        self.config = config or TRACKING_CONFIG
        self.calibrator = calibrator

        # Target bookkeeping (guarded by targets_lock).
        self.active_targets: Dict[int, TrackedPerson] = {}
        self.target_order: List[int] = []  # round-robin visiting order of track ids
        self.current_index: int = 0  # slot in target_order the PTZ worker visits next
        self._last_seen_time: Dict[int, float] = {}  # track id -> time.monotonic() of last detection
        self.targets_lock = threading.Lock()

        # Captures of the current round, uploaded when the round wraps.
        self.batch_captures: List[CaptureRecord] = []
        self.batch_lock = threading.Lock()

        # Per-target capture counters enforcing max_capture_per_target.
        self._capture_counts: Dict[int, int] = {}
        self._capture_counts_lock = threading.Lock()

        self.running = False
        self._detection_thread = None
        self._ptz_thread = None
        # Set by stop() so long waits in the PTZ path end early; cleared by start().
        self._stop_event = threading.Event()
        # The workers block on _paused_event (set = run, cleared = paused);
        # _paused records the requested pause state for start()/readers.
        self._paused = False
        self._paused_event = threading.Event()
        self._paused_event.set()

        # time.monotonic() of the last PTZ move command; -inf means "never sent".
        self._last_ptz_command_time = float("-inf")
        self._last_ptz_command_time_lock = threading.Lock()

        self.target_selector = TargetSelector(self.config.get("target_selection", {}))
        self.event_pusher = None

        self.stats = {
            "frames_processed": 0,
            "persons_detected": 0,
            "captures": 0,
            "uploads": 0,
            "start_time": None,
        }
        self.stats_lock = threading.Lock()
        self._ensure_capture_dir()

    def set_event_pusher(self, event_pusher):
        """Attach the pusher used for uploads.

        It must provide ``upload_numpy_image`` and ``push_tracking_capture``.
        """
        self.event_pusher = event_pusher

    def _ensure_capture_dir(self):
        """Create ``capture_dir`` if needed; failures are only logged."""
        capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
        try:
            os.makedirs(capture_dir, exist_ok=True)
        except OSError as e:
            logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")

    def start(self) -> bool:
        """Connect both cameras, start the panorama stream and launch the workers.

        Returns:
            True once the workers run (also when they were already running).
            False if a connection or the panorama stream fails; every device
            opened so far is disconnected again in that case.
        """
        if self.running:
            # A second start() would reconnect the cameras and spawn duplicate workers.
            logger.warning("轮询跟踪抓拍协调器已在运行,忽略重复启动")
            return True
        if not self.panorama.connect():
            logger.error("连接全景摄像头失败")
            return False
        if not self.ptz.connect():
            logger.error("连接球机失败")
            self.panorama.disconnect()
            return False
        if not self.panorama.start_stream_rtsp():
            logger.error("启动全景视频流失败")
            self.panorama.disconnect()
            self.ptz.disconnect()
            return False

        self._stop_event.clear()
        # stop() releases the pause event so blocked workers can exit; re-arm
        # it if a pause is still requested, so the new workers block instead
        # of spinning.
        if self._paused:
            self._paused_event.clear()
        self.running = True
        with self.stats_lock:
            self.stats["start_time"] = time.time()
        self._detection_thread = threading.Thread(target=self._detection_worker, daemon=True)
        self._detection_thread.start()
        self._ptz_thread = threading.Thread(target=self._ptz_worker, daemon=True)
        self._ptz_thread.start()
        logger.info("轮询跟踪抓拍协调器已启动")
        return True

    def stop(self):
        """Stop both workers, flush pending captures and release all devices.

        Each worker is joined for at most 3 seconds. Waits inside the PTZ
        capture path are interrupted, so a capture in progress aborts quickly.
        """
        self.running = False
        self._stop_event.set()
        # Wake workers blocked on a pause so they can observe running == False.
        self._paused_event.set()
        for worker in (self._detection_thread, self._ptz_thread):
            if worker is not None:
                worker.join(timeout=3)

        # Upload whatever the unfinished round captured.
        self._flush_batch_if_needed()

        # Stop the video stream before disconnecting.
        if hasattr(self.panorama, "stop_stream_rtsp"):
            try:
                self.panorama.stop_stream_rtsp()
            except Exception as e:
                logger.warning(f"停止全景视频流失败: {e}")
        self.panorama.disconnect()
        self.ptz.disconnect()

        # Release tracker resources.
        if self.tracker is not None and hasattr(self.tracker, "release"):
            try:
                self.tracker.release()
            except Exception as e:
                logger.warning(f"释放跟踪器失败: {e}")
        logger.info("轮询跟踪抓拍协调器已停止")

    def pause(self):
        """Pause both workers; they block until ``resume()`` or ``stop()``."""
        self._paused = True
        self._paused_event.clear()

    def resume(self):
        """Resume workers paused by ``pause()``."""
        self._paused = False
        self._paused_event.set()

    def pause_detection(self):
        """Pause detection (compatibility alias for the older coordinator API)."""
        self.pause()

    def resume_detection(self):
        """Resume detection (compatibility alias for the older coordinator API)."""
        self.resume()

    def _interruptible_sleep(self, seconds: float) -> bool:
        """Sleep up to ``seconds``, returning early once ``stop()`` is called.

        Returns:
            True if the caller may proceed, False if a stop was requested.
        """
        if seconds > 0:
            self._stop_event.wait(seconds)
        return not self._stop_event.is_set()

    def _detection_worker(self):
        """Detection thread: fetch panorama frames and run the tracker at ``detection_fps``.

        Frames are polled about every 10 ms. ``frames_processed`` counts every
        fetched frame, while tracking runs at most once per detection interval.
        Errors are logged and the loop continues after a short back-off.
        """
        detection_fps = self.config.get("detection_fps", 2)
        if not detection_fps or detection_fps <= 0:
            # 0/None would raise outside the try block below and silently kill
            # this thread; negative rates are meaningless.
            logger.warning(f"detection_fps 配置无效 ({detection_fps}),使用默认值 2")
            detection_fps = 2
        detection_interval = 1.0 / detection_fps
        last_detection = float("-inf")

        while self.running:
            try:
                if not self._paused_event.is_set():
                    # Block (no busy-wait) until resume() or stop() sets the event.
                    self._paused_event.wait()
                    continue
                frame = self.panorama.get_frame()
                if frame is None:
                    time.sleep(0.01)
                    continue
                self._update_stats("frames_processed")
                now = time.monotonic()
                if now - last_detection >= detection_interval:
                    last_detection = now
                    tracked = self.tracker.update(frame)
                    self._update_active_targets(tracked, frame.shape)
                    if tracked:
                        self._update_stats("persons_detected", len(tracked))
                time.sleep(0.01)
            except Exception as e:
                logger.error(f"检测线程错误: {e}")
                time.sleep(0.1)

    def _update_active_targets(self, tracked: List[TrackedPerson], frame_shape):
        """Merge one tracker result into the target bookkeeping.

        Persons with a negative ``track_id`` are ignored. Visible persons are
        (re)registered and appended to ``target_order`` when new. Known targets
        missing from this result are flagged ``lost``. Lost targets unseen for
        ``tracking_timeout`` seconds are forgotten. Finally, the set is pruned
        to ``max_tracking_targets``.

        Args:
            tracked: persons returned by the tracker for this frame (may be empty).
            frame_shape: shape of the panorama frame; ``[:2]`` is (height, width).
        """
        now = time.monotonic()
        frame_h, frame_w = frame_shape[:2]
        timeout = self.config.get("tracking_timeout", 3.0)
        max_targets = self.config.get("max_tracking_targets", 4)

        with self.targets_lock:
            # Update or add the visible persons.
            seen_ids = set()
            for person in tracked or ():
                if person.track_id < 0:
                    continue
                seen_ids.add(person.track_id)
                person.lost = False
                self.active_targets[person.track_id] = person
                self._last_seen_time[person.track_id] = now
                if person.track_id not in self.target_order:
                    self.target_order.append(person.track_id)

            # Flag known targets that were not reported this time.
            for tid in self.target_order:
                if tid not in seen_ids:
                    known = self.active_targets.get(tid)
                    if known is not None:
                        known.lost = True

            # Forget entries with no target object, or lost for >= tracking_timeout.
            stale_ids = []
            for tid in self.target_order:
                known = self.active_targets.get(tid)
                if known is None:
                    stale_ids.append(tid)
                elif known.lost and now - self._last_seen_time.get(tid, now) >= timeout:
                    stale_ids.append(tid)
            for tid in stale_ids:
                self._remove_target_locked(tid)

            # Enforce the maximum number of tracked targets.
            if len(self.active_targets) > max_targets:
                self._prune_targets(frame_w, frame_h, max_targets)

    def _remove_target_locked(self, track_id: int):
        """Forget ``track_id`` everywhere (caller holds ``targets_lock``).

        ``current_index`` keeps pointing at the same upcoming target. Removing
        an entry before it shifts the index left by one. Removing the entry at
        the index leaves it on that entry's successor.
        """
        try:
            position = self.target_order.index(track_id)
        except ValueError:
            position = None
        if position is not None:
            del self.target_order[position]
            if position < self.current_index:
                self.current_index -= 1
        self.active_targets.pop(track_id, None)
        self._last_seen_time.pop(track_id, None)
        with self._capture_counts_lock:
            self._capture_counts.pop(track_id, None)

    def _prune_targets(self, frame_w: int, frame_h: int, max_targets: int):
        """Keep only the ``max_targets`` best-scoring targets (caller holds ``targets_lock``).

        Scores come from ``target_selector.calculate_score``. Its inputs are
        the normalized centre, bbox area, confidence and distance to the frame
        centre. The remaining targets are removed.
        """
        frame_size = (frame_w, frame_h)
        now = time.time()
        scored = []
        for person in self.active_targets.values():
            # Presumably bbox is (x1, y1, x2, y2) — matches the area formula below.
            area = (person.bbox[2] - person.bbox[0]) * (person.bbox[3] - person.bbox[1])
            candidate = TrackingTarget(
                track_id=person.track_id,
                position=(person.center[0] / frame_w, person.center[1] / frame_h),
                last_update=now,
                area=area,
                confidence=person.confidence,
                center_distance=self._center_distance(person.center, frame_size),
            )
            candidate.score = self.target_selector.calculate_score(candidate, frame_size)
            scored.append(candidate)

        scored.sort(key=lambda c: c.score, reverse=True)
        keep_ids = {c.track_id for c in scored[:max_targets]}
        for tid in [tid for tid in self.active_targets if tid not in keep_ids]:
            self._remove_target_locked(tid)

    def _center_distance(self, center: Tuple[int, int], frame_size: Tuple[int, int]) -> float:
        """Return the mean of |dx| and |dy| from the frame centre, each scaled by half the frame size.

        The value is 0 at the centre and 1 at a corner.
        """
        cx, cy = frame_size[0] / 2, frame_size[1] / 2
        dx = abs(center[0] - cx) / cx
        dy = abs(center[1] - cy) / cy
        return (dx + dy) / 2

    def _ptz_worker(self):
        """PTZ thread: visit ``target_order`` round robin and capture each visible target.

        Lost or missing targets are skipped but stay queued until the
        detection thread times them out. When the visit index wraps back to 0,
        the round is complete and the queued captures are uploaded. With no
        targets at all, pending captures are flushed and the thread idles.
        """
        while self.running:
            try:
                if not self._paused_event.is_set():
                    # Block until resume() or stop() sets the event.
                    self._paused_event.wait()
                    continue

                # Pick the current target consistently with the detection thread.
                with self.targets_lock:
                    target_id = None
                    target = None
                    if self.target_order and self.active_targets:
                        if self.current_index >= len(self.target_order):
                            self.current_index = 0
                        target_id = self.target_order[self.current_index]
                        target = self.active_targets.get(target_id)

                if target_id is None:
                    self._flush_batch_if_needed()
                    self._interruptible_sleep(0.1)
                    continue

                record = None
                if target is not None and not target.lost:
                    record = self._capture_one(target)
                if record is not None:
                    with self.batch_lock:
                        self.batch_captures.append(record)
                    self._update_stats("captures")
                else:
                    # Skipped (lost/missing/capped) or failed: back off briefly
                    # so a queue of un-capturable targets does not busy-spin.
                    self._interruptible_sleep(0.01)

                with self.targets_lock:
                    self._advance_past(target_id)
                    round_complete = self.current_index == 0
                if round_complete:
                    self._flush_batch_if_needed()
            except Exception as e:
                logger.error(f"PTZ 线程错误: {e}")
                time.sleep(0.1)

    def _advance(self, order_len: Optional[int] = None):
        """Step ``current_index`` by one, wrapping at ``order_len`` (caller holds ``targets_lock``).

        Kept for compatibility; the PTZ worker uses ``_advance_past``. A missing
        or zero ``order_len`` falls back to the current order length (min 1).
        """
        if not order_len:
            order_len = len(self.target_order) or 1
        self.current_index = (self.current_index + 1) % order_len

    def _advance_past(self, target_id: int):
        """Point ``current_index`` at the target after ``target_id`` (caller holds ``targets_lock``).

        The position is looked up in the *current* ``target_order``, because
        the detection thread may add or remove targets during a capture. If
        ``target_id`` itself was removed, ``_remove_target_locked`` already left
        ``current_index`` on its successor, so that slot is kept.
        """
        order_len = len(self.target_order)
        if order_len == 0:
            self.current_index = 0
            return
        try:
            position = self.target_order.index(target_id)
        except ValueError:
            position = self.current_index - 1
        self.current_index = (position + 1) % order_len

    def _capture_one(self, target: TrackedPerson) -> Optional[CaptureRecord]:
        """Steer the PTZ camera to ``target``, grab a sharp frame and build a record.

        Returns ``None`` without counting a capture in these cases:
        the target already reached ``max_capture_per_target``; no panorama or
        PTZ frame is available; the PTZ move fails; or ``stop()`` is requested
        mid-capture.

        On success the PTZ frame is written to ``capture_dir``. The panorama
        frame is written too when ``save_panorama_pair`` is enabled.
        """
        track_id = target.track_id
        max_cap = self.config.get("max_capture_per_target", 0)
        # Check the limit BEFORE moving the camera. A capped target would
        # otherwise cost a full move + stabilize cycle only to be rejected.
        if max_cap > 0:
            with self._capture_counts_lock:
                if self._capture_counts.get(track_id, 0) >= max_cap:
                    return None

        frame = self.panorama.get_frame()
        if frame is None:
            return None
        frame_h, frame_w = frame.shape[:2]
        # Target centre as fractions of the panorama width / height.
        x_ratio = target.center[0] / frame_w
        y_ratio = target.center[1] / frame_h

        if self.calibrator and self.calibrator.is_calibrated():
            # The calibrator maps panorama coordinates straight to PTZ angles
            # (built from a real scan), so its output is sent as-is: no pan_flip.
            pan, tilt = self.calibrator.transform(x_ratio, y_ratio)
            ptz_config = getattr(self.ptz, "ptz_config", None) or {}
            zoom = ptz_config.get("default_zoom", 8)
        else:
            pan, tilt, zoom = self.ptz.calculate_ptz_position(x_ratio, y_ratio)

        cooldown = self.config.get("ptz_command_cooldown", 0.2)
        with self._last_ptz_command_time_lock:
            # Check, send and stamp under one lock so the cooldown holds for
            # every command. Failed attempts are stamped too, so retries are
            # rate-limited as well.
            wait = cooldown - (time.monotonic() - self._last_ptz_command_time)
            if wait > 0 and not self._interruptible_sleep(wait):
                return None
            success = self.ptz.goto_exact_position(pan, tilt, zoom)
            self._last_ptz_command_time = time.monotonic()
        if not success:
            return None

        # Let the camera settle (never shorter than the cooldown) before grabbing frames.
        stabilize = self.config.get("ptz_stabilize_time", 2.0)
        if not self._interruptible_sleep(max(stabilize, cooldown)):
            return None
        ptz_frame = self._get_clear_ptz_frame()
        if ptz_frame is None:
            return None

        with self._capture_counts_lock:
            # Re-check under the lock right before counting this capture.
            count = self._capture_counts.get(track_id, 0)
            if max_cap > 0 and count >= max_cap:
                return None
            self._capture_counts[track_id] = count + 1

        panorama_image = frame.copy() if self.config.get("save_panorama_pair", True) else None
        self._save_local(ptz_frame, panorama_image, target, pan, tilt, zoom)
        return CaptureRecord(
            track_id=track_id,
            timestamp=time.time(),
            position=(x_ratio, y_ratio),
            ptz_position=(pan, tilt, zoom),
            ptz_image=ptz_frame,
            panorama_image=panorama_image,
            confidence=target.confidence,
        )

    def _get_clear_ptz_frame(self, max_attempts: int = 5, wait_interval: float = 0.2) -> Optional[np.ndarray]:
        """Return the sharpest of up to ``max_attempts`` PTZ frames, or None if none arrived.

        Sharpness is the variance of the Laplacian of the grayscale frame.
        Frames are copied before scoring. The wait happens only *between*
        attempts and stops early if ``stop()`` is requested.
        """
        best_frame = None
        best_score = -1.0
        for attempt in range(max_attempts):
            if attempt and not self._interruptible_sleep(wait_interval):
                break
            frame = self.ptz.get_frame()
            if frame is None:
                continue
            candidate = frame.copy()
            gray = candidate if candidate.ndim == 2 else cv2.cvtColor(candidate, cv2.COLOR_BGR2GRAY)
            score = cv2.Laplacian(gray, cv2.CV_64F).var()
            if score > best_score:
                best_score = score
                best_frame = candidate
        return best_frame

    def _save_local(self, ptz_frame, panorama_image, target, pan, tilt, zoom):
        """Write the PTZ frame (and the panorama frame, if given) as JPEGs under ``capture_dir``.

        Best effort: directory or write failures are logged, never raised.
        """
        capture_dir = self.config.get("capture_dir", "/home/admin/dsh/tracking_captures")
        try:
            os.makedirs(capture_dir, exist_ok=True)
        except OSError as e:
            logger.warning(f"无法创建抓拍目录 {capture_dir}: {e}")
            return
        timestamp = int(time.time() * 1000)
        ptz_path = os.path.join(
            capture_dir, f"ptz_{target.track_id}_{timestamp}_{pan:.0f}_{tilt:.0f}_z{zoom}.jpg"
        )
        self._write_image(ptz_path, ptz_frame)
        if panorama_image is not None:
            panorama_path = os.path.join(capture_dir, f"panorama_{target.track_id}_{timestamp}.jpg")
            self._write_image(panorama_path, panorama_image)

    @staticmethod
    def _write_image(path: str, image) -> bool:
        """Write ``image`` with ``cv2.imwrite``; log and return False on failure."""
        try:
            ok = cv2.imwrite(path, image)
        except cv2.error as e:
            logger.warning(f"保存抓拍图片失败 {path}: {e}")
            return False
        if not ok:
            logger.warning(f"保存抓拍图片失败 {path}")
        return bool(ok)

    def _upload_batch(self, records: List[CaptureRecord]):
        """Upload ``records`` through ``event_pusher`` as one tracking-capture event.

        This is a no-op when there are no records, no pusher, or
        ``enable_upload`` is false. Each PTZ image (and panorama image, if
        present) is uploaded first. One ``push_tracking_capture`` call then
        carries the returned URLs and metadata. Any error is logged, and the
        batch is dropped (best effort).
        """
        if not records or not self.event_pusher or not self.config.get("enable_upload", True):
            return
        try:
            uploads = []
            for record in records:
                ptz_url = self.event_pusher.upload_numpy_image(record.ptz_image)
                pan_url = None
                if record.panorama_image is not None:
                    pan_url = self.event_pusher.upload_numpy_image(record.panorama_image)
                uploads.append({
                    "track_id": record.track_id,
                    "ptz_image_url": ptz_url,
                    "panorama_image_url": pan_url,
                    "position": record.position,
                    "ptz_position": record.ptz_position,
                    "confidence": record.confidence,
                    "timestamp": record.timestamp,
                })
            self.event_pusher.push_tracking_capture(batch_time=time.time(), captures=uploads)
            self._update_stats("uploads")
        except Exception as e:
            logger.error(f"批量上传失败: {e}")

    def _flush_batch_if_needed(self):
        """Upload and clear the queued captures, if any.

        The queue is copied and cleared under ``batch_lock``. The (network)
        upload itself runs outside the lock, so ``get_results()`` and the
        capture path are not blocked by it.
        """
        with self.batch_lock:
            if not self.batch_captures:
                return
            pending = list(self.batch_captures)
            self.batch_captures.clear()
        self._upload_batch(pending)

    def _update_stats(self, key: str, value: int = 1):
        """Add ``value`` to counter ``key``; unknown keys are ignored."""
        with self.stats_lock:
            if key in self.stats:
                self.stats[key] += value

    def get_results(self) -> List[CaptureRecord]:
        """Return a copy of the captures queued for the current, not yet uploaded, round.

        Compatibility method for the older coordinator API.
        """
        with self.batch_lock:
            return self.batch_captures.copy()

    def get_stats(self) -> Dict:
        """Return a shallow copy of the runtime counters."""
        with self.stats_lock:
            return self.stats.copy()
|