paired_image_saver.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. """
  2. 配对图片保存管理器
  3. 将全景检测图片和对应的球机聚焦图片保存到同一目录
  4. """
  5. import os
  6. import cv2
  7. import time
  8. import logging
  9. import threading
  10. from pathlib import Path
  11. from datetime import datetime
  12. from typing import Optional, List, Dict, Tuple
  13. from dataclasses import dataclass, field
  14. logger = logging.getLogger(__name__)
  15. @dataclass
  16. class PersonTrackingInfo:
  17. """人员跟踪信息"""
  18. track_id: int
  19. position: Tuple[float, float] # (x_ratio, y_ratio)
  20. bbox: Tuple[int, int, int, int] # (x1, y1, x2, y2)
  21. confidence: float
  22. ptz_position: Optional[Tuple[float, float, int]] = None # (pan, tilt, zoom)
  23. ptz_bbox: Optional[Tuple[int, int, int, int]] = None # 球机图中检测到的bbox (x1, y1, x2, y2)
  24. ptz_image_saved: bool = False
  25. ptz_image_path: Optional[str] = None
  26. @dataclass
  27. class DetectionBatch:
  28. """一批检测记录"""
  29. batch_id: str
  30. timestamp: float
  31. panorama_image: Optional[object] = None # numpy array
  32. panorama_path: Optional[str] = None
  33. persons: List[PersonTrackingInfo] = field(default_factory=list)
  34. total_persons: int = 0
  35. ptz_images_count: int = 0
  36. completed: bool = False
  37. class PairedImageSaver:
  38. """
  39. 配对图片保存管理器
  40. 功能:
  41. 1. 为每次全景检测创建批次目录
  42. 2. 保存全景标记图到批次目录
  43. 3. 为每个人员保存对应的球机聚焦图到同一目录
  44. 4. 支持时间窗口内的批量保存
  45. """
  46. def __init__(self, base_dir: str = '/home/admin/dsh/paired_images',
  47. time_window: float = 5.0, # 时间窗口(秒)
  48. max_batches: int = 100):
  49. """
  50. 初始化
  51. Args:
  52. base_dir: 基础保存目录
  53. time_window: 批次时间窗口(秒),同一窗口内的检测归为一批
  54. max_batches: 最大保留批次数量
  55. """
  56. self.base_dir = Path(base_dir)
  57. self.time_window = time_window
  58. self.max_batches = max_batches
  59. self._current_batch: Optional[DetectionBatch] = None
  60. self._batch_lock = threading.Lock()
  61. self._last_batch_time = 0.0
  62. # 统计信息
  63. self._stats = {
  64. 'total_batches': 0,
  65. 'total_persons': 0,
  66. 'total_ptz_images': 0
  67. }
  68. self._stats_lock = threading.Lock()
  69. # 确保目录存在
  70. self._ensure_base_dir()
  71. logger.info(f"[配对保存] 初始化完成: 目录={base_dir}, 时间窗口={time_window}s")
  72. def _ensure_base_dir(self):
  73. """确保基础目录存在"""
  74. try:
  75. self.base_dir.mkdir(parents=True, exist_ok=True)
  76. except Exception as e:
  77. logger.error(f"[配对保存] 创建目录失败: {e}")
  78. def _generate_batch_id(self) -> str:
  79. """生成批次ID"""
  80. return datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
  81. def _create_batch_dir(self, batch_id: str) -> Path:
  82. """创建批次目录"""
  83. batch_dir = self.base_dir / f"batch_{batch_id}"
  84. try:
  85. batch_dir.mkdir(parents=True, exist_ok=True)
  86. return batch_dir
  87. except Exception as e:
  88. logger.error(f"[配对保存] 创建批次目录失败: {e}")
  89. return self.base_dir
  90. def start_new_batch(self, panorama_frame, persons: List[Dict]) -> Optional[str]:
  91. """
  92. 开始新批次
  93. Args:
  94. panorama_frame: 全景帧图像
  95. persons: 人员列表,每项包含 track_id, position, bbox, confidence
  96. Returns:
  97. batch_id: 批次ID,失败返回 None
  98. """
  99. with self._batch_lock:
  100. current_time = time.time()
  101. # 完成上一批次(如果有)
  102. # 注意:每次检测都创建独立批次,不复用,确保 batch_info 与实际检测一致
  103. if self._current_batch is not None:
  104. self._finalize_batch(self._current_batch)
  105. # 创建新批次
  106. batch_id = self._generate_batch_id()
  107. batch_dir = self._create_batch_dir(batch_id)
  108. # 保存全景图片
  109. panorama_path = None
  110. if panorama_frame is not None:
  111. panorama_path = self._save_panorama_image(
  112. batch_dir, batch_id, panorama_frame, persons
  113. )
  114. # 创建人员跟踪信息
  115. person_infos = []
  116. for i, p in enumerate(persons):
  117. info = PersonTrackingInfo(
  118. track_id=p.get('track_id', i),
  119. position=p.get('position', (0, 0)),
  120. bbox=p.get('bbox', (0, 0, 0, 0)),
  121. confidence=p.get('confidence', 0.0)
  122. )
  123. person_infos.append(info)
  124. # 创建批次记录
  125. self._current_batch = DetectionBatch(
  126. batch_id=batch_id,
  127. timestamp=current_time,
  128. panorama_image=panorama_frame,
  129. panorama_path=panorama_path,
  130. persons=person_infos,
  131. total_persons=len(persons)
  132. )
  133. self._last_batch_time = current_time
  134. with self._stats_lock:
  135. self._stats['total_batches'] += 1
  136. self._stats['total_persons'] += len(persons)
  137. logger.info(
  138. f"[配对保存] 新批次创建: {batch_id}, "
  139. f"人员={len(persons)}, 目录={batch_dir}"
  140. )
  141. return batch_id
  142. def _save_panorama_image(self, batch_dir: Path, batch_id: str,
  143. frame, persons: List[Dict]) -> Optional[str]:
  144. """
  145. 保存全景标记图片
  146. Args:
  147. batch_dir: 批次目录
  148. batch_id: 批次ID
  149. frame: 全景帧
  150. persons: 人员列表(已由调用方过滤,此处不再过滤)
  151. Returns:
  152. 保存路径或 None
  153. """
  154. try:
  155. # 复制图像避免修改原图
  156. marked_frame = frame.copy()
  157. # 绘制每个人员的标记(使用连续的序号)
  158. # 注意:persons 已由调用方(coordinator)过滤,置信度均 >= 阈值
  159. for i, person in enumerate(persons):
  160. bbox = person.get('bbox', (0, 0, 0, 0))
  161. x1, y1, x2, y2 = bbox
  162. conf = person.get('confidence', 0.0)
  163. # 绘制边界框(绿色)
  164. cv2.rectangle(marked_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
  165. # 绘制序号标签(带置信度)
  166. label = f"person_{i}({conf:.2f})"
  167. (label_w, label_h), baseline = cv2.getTextSize(
  168. label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
  169. )
  170. # 标签背景
  171. cv2.rectangle(
  172. marked_frame,
  173. (x1, y1 - label_h - 8),
  174. (x1 + label_w, y1),
  175. (0, 255, 0),
  176. -1
  177. )
  178. # 标签文字(黑色)
  179. cv2.putText(
  180. marked_frame, label,
  181. (x1, y1 - 4),
  182. cv2.FONT_HERSHEY_SIMPLEX, 0.8,
  183. (0, 0, 0), 2
  184. )
  185. # 保存图片(使用人员数量)
  186. filename = f"00_panorama_n{len(persons)}.jpg"
  187. filepath = batch_dir / filename
  188. cv2.imwrite(str(filepath), marked_frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
  189. logger.info(f"[配对保存] 全景图已保存: {filepath},人员数量 {len(persons)}")
  190. return str(filepath)
  191. except Exception as e:
  192. logger.error(f"[配对保存] 保存全景图失败: {e}")
  193. return None
  194. def save_ptz_image(self, batch_id: str, person_index: int,
  195. ptz_frame, ptz_position: Tuple[float, float, int],
  196. ptz_bbox: Tuple[int, int, int, int] = None,
  197. person_info: Dict = None) -> Optional[str]:
  198. """
  199. 保存球机聚焦图片
  200. Args:
  201. batch_id: 批次ID
  202. person_index: 人员序号(0-based)
  203. ptz_frame: 球机帧
  204. ptz_position: PTZ位置 (pan, tilt, zoom)
  205. ptz_bbox: 球机图中检测到的bbox (x1, y1, x2, y2)
  206. person_info: 额外人员信息
  207. Returns:
  208. 保存路径或 None
  209. """
  210. with self._batch_lock:
  211. if self._current_batch is None or self._current_batch.batch_id != batch_id:
  212. logger.warning(f"[配对保存] 批次不存在或已过期: {batch_id}")
  213. return None
  214. batch_dir = self.base_dir / f"batch_{batch_id}"
  215. try:
  216. # 复制图像
  217. marked_frame = ptz_frame.copy() if ptz_frame is not None else None
  218. if marked_frame is not None:
  219. # 在球机图上添加标记(如果有检测框)
  220. h, w = marked_frame.shape[:2]
  221. # 添加PTZ位置信息到图片
  222. pan, tilt, zoom = ptz_position
  223. info_text = f"PTZ: P={pan:.1f} T={tilt:.1f} Z={zoom}"
  224. cv2.putText(
  225. marked_frame, info_text,
  226. (10, 30),
  227. cv2.FONT_HERSHEY_SIMPLEX, 0.7,
  228. (0, 255, 0), 2
  229. )
  230. # 添加人员序号
  231. person_text = f"person_{person_index}"
  232. cv2.putText(
  233. marked_frame, person_text,
  234. (10, 60),
  235. cv2.FONT_HERSHEY_SIMPLEX, 0.7,
  236. (0, 255, 0), 2
  237. )
  238. # 绘制PTZ检测到的bbox(红色)
  239. if ptz_bbox is not None:
  240. x1, y1, x2, y2 = ptz_bbox
  241. cv2.rectangle(marked_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
  242. bbox_text = f"PTZ_BBox: ({x1},{y1},{x2},{y2})"
  243. cv2.putText(
  244. marked_frame, bbox_text,
  245. (10, 90),
  246. cv2.FONT_HERSHEY_SIMPLEX, 0.6,
  247. (0, 0, 255), 2
  248. )
  249. # 保存图片
  250. filename = f"01_ptz_person{person_index}_p{int(ptz_position[0])}_t{int(ptz_position[1])}_z{int(ptz_position[2])}.jpg"
  251. filepath = batch_dir / filename
  252. if marked_frame is not None:
  253. cv2.imwrite(str(filepath), marked_frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
  254. # 更新批次信息
  255. if person_index < len(self._current_batch.persons):
  256. self._current_batch.persons[person_index].ptz_position = ptz_position
  257. self._current_batch.persons[person_index].ptz_bbox = ptz_bbox
  258. self._current_batch.persons[person_index].ptz_image_saved = True
  259. self._current_batch.persons[person_index].ptz_image_path = str(filepath)
  260. self._current_batch.ptz_images_count += 1
  261. with self._stats_lock:
  262. self._stats['total_ptz_images'] += 1
  263. logger.info(f"[配对保存] 球机图已保存: {filepath}, BBox={ptz_bbox}")
  264. return str(filepath)
  265. except Exception as e:
  266. logger.error(f"[配对保存] 保存球机图失败: {e}")
  267. return None
  268. def _finalize_batch(self, batch: DetectionBatch):
  269. """完成批次处理"""
  270. batch.completed = True
  271. # 创建批次信息文件
  272. try:
  273. batch_dir = self.base_dir / f"batch_{batch.batch_id}"
  274. info_path = batch_dir / "batch_info.txt"
  275. with open(info_path, 'w', encoding='utf-8') as f:
  276. f.write(f"批次ID: {batch.batch_id}\n")
  277. f.write(f"时间戳: {datetime.fromtimestamp(batch.timestamp)}\n")
  278. f.write(f"总人数: {batch.total_persons}\n")
  279. f.write(f"球机图数量: {batch.ptz_images_count}\n")
  280. f.write(f"全景图: {batch.panorama_path}\n")
  281. f.write("\n人员详情:\n")
  282. for i, person in enumerate(batch.persons):
  283. f.write(f"\n Person {i}:\n")
  284. f.write(f" Track ID: {person.track_id}\n")
  285. f.write(f" Position: ({person.position[0]:.3f}, {person.position[1]:.3f})\n")
  286. f.write(f" BBox: ({person.bbox[0]}, {person.bbox[1]}, {person.bbox[2]}, {person.bbox[3]})\n")
  287. f.write(f" Confidence: {person.confidence:.2f}\n")
  288. f.write(f" PTZ Position: {person.ptz_position}\n")
  289. if person.ptz_bbox:
  290. f.write(f" PTZ BBox: ({person.ptz_bbox[0]}, {person.ptz_bbox[1]}, {person.ptz_bbox[2]}, {person.ptz_bbox[3]})\n")
  291. else:
  292. f.write(f" PTZ BBox: None\n")
  293. f.write(f" PTZ Image: {person.ptz_image_path}\n")
  294. logger.info(f"[配对保存] 批次完成: {batch.batch_id}, "
  295. f"人员={batch.total_persons}, 球机图={batch.ptz_images_count}")
  296. except Exception as e:
  297. logger.error(f"[配对保存] 保存批次信息失败: {e}")
  298. # 清理旧批次
  299. self._cleanup_old_batches()
  300. def _cleanup_old_batches(self):
  301. """清理旧批次目录"""
  302. try:
  303. batch_dirs = sorted(
  304. [d for d in self.base_dir.iterdir() if d.is_dir() and d.name.startswith('batch_')],
  305. key=lambda x: x.stat().st_mtime
  306. )
  307. if len(batch_dirs) > self.max_batches:
  308. to_delete = batch_dirs[:len(batch_dirs) - self.max_batches]
  309. for d in to_delete:
  310. import shutil
  311. shutil.rmtree(d)
  312. logger.info(f"[配对保存] 清理旧批次: {d.name}")
  313. except Exception as e:
  314. logger.error(f"[配对保存] 清理旧批次失败: {e}")
  315. def get_current_batch_id(self) -> Optional[str]:
  316. """获取当前批次ID"""
  317. with self._batch_lock:
  318. return self._current_batch.batch_id if self._current_batch else None
  319. def get_stats(self) -> Dict:
  320. """获取统计信息"""
  321. with self._stats_lock:
  322. return self._stats.copy()
  323. def close(self):
  324. """关闭管理器,完成当前批次"""
  325. with self._batch_lock:
  326. if self._current_batch is not None:
  327. self._finalize_batch(self._current_batch)
  328. self._current_batch = None
  329. logger.info("[配对保存] 管理器已关闭")
  330. # 全局单例实例
  331. _paired_saver_instance: Optional[PairedImageSaver] = None
  332. def get_paired_saver(base_dir: str = None, time_window: float = 5.0) -> PairedImageSaver:
  333. """
  334. 获取配对保存管理器实例(单例模式)
  335. Args:
  336. base_dir: 基础保存目录
  337. time_window: 时间窗口
  338. Returns:
  339. PairedImageSaver 实例
  340. """
  341. global _paired_saver_instance
  342. if _paired_saver_instance is None:
  343. _paired_saver_instance = PairedImageSaver(
  344. base_dir=base_dir or '/home/admin/dsh/paired_images',
  345. time_window=time_window
  346. )
  347. return _paired_saver_instance
  348. def reset_paired_saver():
  349. """重置单例实例(用于测试)"""
  350. global _paired_saver_instance
  351. if _paired_saver_instance is not None:
  352. _paired_saver_instance.close()
  353. _paired_saver_instance = None