# panorama_camera.py
  1. """
  2. 全景摄像头模块
  3. 负责获取视频流和物体检测
  4. """
  5. import os
  6. # 必须在导入cv2之前设置,防止FFmpeg多线程解码崩溃
  7. # pthread_frame.c:167 async_lock assertion
  8. os.environ['OPENCV_FFMPEG_CAPTURE_OPTIONS'] = 'threads;1'
  9. import cv2
  10. import numpy as np
  11. import threading
  12. import queue
  13. import time
  14. import logging
  15. from datetime import datetime
  16. from typing import Optional, List, Tuple, Dict, Any
  17. from dataclasses import dataclass
  18. from pathlib import Path
  19. from config import PANORAMA_CAMERA, DETECTION_CONFIG
  20. from dahua_sdk import DahuaSDK, PTZCommand
  21. from video_lock import safe_read, safe_is_opened
  22. logger = logging.getLogger(__name__)
  23. @dataclass
  24. class DetectedObject:
  25. """检测到的物体"""
  26. class_name: str # 类别名称
  27. confidence: float # 置信度
  28. bbox: Tuple[int, int, int, int] # 边界框 (x, y, width, height)
  29. center: Tuple[int, int] # 中心点坐标
  30. track_id: Optional[int] = None # 跟踪ID
  31. class PanoramaCamera:
  32. """全景摄像头类"""
  33. def __init__(self, sdk: DahuaSDK, camera_config: Dict = None):
  34. """
  35. 初始化全景摄像头
  36. Args:
  37. sdk: 大华SDK实例
  38. camera_config: 摄像头配置
  39. """
  40. self.sdk = sdk
  41. self.config = camera_config or PANORAMA_CAMERA
  42. self.login_handle = None
  43. self.play_handle = None
  44. self.connected = False
  45. # 视频流
  46. self.frame_queue = queue.Queue(maxsize=10)
  47. self.current_frame = None
  48. self.frame_lock = threading.Lock()
  49. self.rtsp_cap = None # RTSP视频捕获
  50. self._camera_id = 'panorama' # 用于per-camera锁
  51. # 检测器
  52. self.detector = None
  53. # 控制标志
  54. self.running = False
  55. self.stream_thread = None
  56. # 断线重连
  57. self.auto_reconnect = True
  58. self.reconnect_interval = 5.0 # 重连间隔(秒)
  59. self.max_reconnect_attempts = 3 # 最大重连次数
  60. def connect(self) -> bool:
  61. """
  62. 连接摄像头
  63. Returns:
  64. 是否成功
  65. """
  66. login_handle, error = self.sdk.login(
  67. self.config['ip'],
  68. self.config['port'],
  69. self.config['username'],
  70. self.config['password']
  71. )
  72. if login_handle is None:
  73. print(f"连接全景摄像头失败: IP={self.config['ip']}, 错误码={error}")
  74. return False
  75. self.login_handle = login_handle
  76. self.connected = True
  77. print(f"成功连接全景摄像头: {self.config['ip']}")
  78. return True
  79. def disconnect(self):
  80. """断开连接"""
  81. self.stop_stream()
  82. if self.login_handle:
  83. self.sdk.logout(self.login_handle)
  84. self.login_handle = None
  85. self.connected = False
  86. def start_stream(self) -> bool:
  87. """
  88. 开始视频流
  89. Returns:
  90. 是否成功
  91. """
  92. if not self.connected:
  93. return False
  94. self.play_handle = self.sdk.real_play(
  95. self.login_handle,
  96. self.config['channel']
  97. )
  98. if self.play_handle is None:
  99. print("启动视频流失败")
  100. return False
  101. self.running = True
  102. self.stream_thread = threading.Thread(target=self._stream_worker, daemon=True)
  103. self.stream_thread.start()
  104. print("视频流已启动")
  105. return True
  106. def start_stream_rtsp(self, rtsp_url: str = None) -> bool:
  107. if rtsp_url is None:
  108. rtsp_url = self.config.get('rtsp_url') or f"rtsp://{self.config['username']}:{self.config['password']}@{self.config['ip']}:{self.config.get('rtsp_port', 554)}/h264/ch{self.config['channel']}/main/av_stream"
  109. try:
  110. # 先尝试FFmpeg后端
  111. self.rtsp_cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
  112. if not self.rtsp_cap.isOpened():
  113. # FFmpeg失败,尝试GStreamer后端
  114. print(f"FFmpeg后端无法打开RTSP流,尝试GStreamer后端...")
  115. try:
  116. gst_cap = cv2.VideoCapture(rtsp_url, cv2.CAP_GSTREAMER)
  117. if gst_cap.isOpened():
  118. self.rtsp_cap = gst_cap
  119. print(f"使用GStreamer后端打开RTSP流成功")
  120. else:
  121. print(f"无法打开RTSP流: {rtsp_url}")
  122. return False
  123. except Exception as ge:
  124. print(f"GStreamer后端也不可用: {ge}")
  125. return False
  126. self.rtsp_cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  127. self.running = True
  128. self.stream_thread = threading.Thread(target=self._rtsp_stream_worker, daemon=True)
  129. self.stream_thread.start()
  130. print(f"RTSP视频流已启动: {rtsp_url}")
  131. return True
  132. except Exception as e:
  133. print(f"RTSP流启动失败: {e}")
  134. return False
  135. def _stream_worker(self):
  136. """视频流工作线程 (SDK模式)"""
  137. retry_count = 0
  138. max_retries = 10
  139. while self.running:
  140. try:
  141. # 尝试从 SDK 帧缓冲区获取帧 (如果可用)
  142. frame_buffer = self.sdk.get_video_frame_buffer(self.config['channel'])
  143. if frame_buffer:
  144. frame_info = frame_buffer.get(timeout=0.1)
  145. if frame_info and frame_info.get('data'):
  146. # 解码帧数据 (如果需要)
  147. # 注意: SDK回调返回的是编码数据,需要解码
  148. # 这里暂时跳过,因为解码需要额外处理
  149. pass
  150. # RTSP 模式获取帧 (推荐方式)
  151. if self.rtsp_cap is not None and safe_is_opened(self.rtsp_cap, self._camera_id):
  152. ret, frame = safe_read(self.rtsp_cap, self._camera_id)
  153. if ret and frame is not None:
  154. with self.frame_lock:
  155. self.current_frame = frame.copy()
  156. try:
  157. self.frame_queue.put(frame.copy(), block=False)
  158. except queue.Full:
  159. pass
  160. retry_count = 0 # 重置重试计数
  161. time.sleep(0.001) # 减少CPU占用
  162. continue
  163. # 如果 RTSP 不可用,尝试自动连接
  164. if retry_count < max_retries:
  165. rtsp_url = self._build_rtsp_url()
  166. try:
  167. if self.rtsp_cap is None:
  168. self.rtsp_cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
  169. self.rtsp_cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) # 减少缓冲延迟
  170. if safe_is_opened(self.rtsp_cap, self._camera_id):
  171. retry_count = 0
  172. continue
  173. except Exception as e:
  174. pass
  175. retry_count += 1
  176. time.sleep(1.0) # 重试间隔
  177. else:
  178. # 超过最大重试次数,使用模拟帧
  179. frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
  180. with self.frame_lock:
  181. self.current_frame = frame
  182. try:
  183. self.frame_queue.put(frame, block=False)
  184. except queue.Full:
  185. pass
  186. time.sleep(0.1)
  187. except Exception as e:
  188. err_str = str(e)
  189. if 'async_lock' in err_str or 'Assertion' in err_str:
  190. print(f"视频流FFmpeg内部错误,重建连接: {e}")
  191. self._reconnect_rtsp()
  192. else:
  193. print(f"视频流错误: {e}")
  194. time.sleep(0.5)
  195. def _build_rtsp_url(self) -> str:
  196. return self.config.get('rtsp_url') or f"rtsp://{self.config['username']}:{self.config['password']}@{self.config['ip']}:{self.config.get('rtsp_port', 554)}/h264/ch{self.config['channel']}/main/av_stream"
  197. def _rtsp_stream_worker(self):
  198. """RTSP视频流工作线程"""
  199. import signal
  200. # 屏蔽SIGINT在此线程,由主线程处理
  201. if hasattr(signal, 'pthread_sigmask'):
  202. try:
  203. signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
  204. except (AttributeError, OSError):
  205. pass
  206. max_consecutive_errors = 50
  207. error_count = 0
  208. while self.running:
  209. try:
  210. if self.rtsp_cap is None or not safe_is_opened(self.rtsp_cap, self._camera_id):
  211. time.sleep(0.1)
  212. continue
  213. ret, frame = safe_read(self.rtsp_cap, self._camera_id)
  214. if not ret or frame is None:
  215. error_count += 1
  216. if error_count > max_consecutive_errors:
  217. print(f"全景RTSP流连续{max_consecutive_errors}次读取失败,尝试重连...")
  218. self._reconnect_rtsp()
  219. error_count = 0
  220. time.sleep(0.01)
  221. continue
  222. error_count = 0
  223. with self.frame_lock:
  224. self.current_frame = frame.copy()
  225. try:
  226. self.frame_queue.put(frame, block=False)
  227. except queue.Full:
  228. pass
  229. except Exception as e:
  230. err_str = str(e)
  231. if 'async_lock' in err_str or 'Assertion' in err_str:
  232. print(f"全景RTSP流FFmpeg内部错误,3秒后重建连接: {e}")
  233. time.sleep(3)
  234. self._reconnect_rtsp()
  235. else:
  236. print(f"全景RTSP视频流错误: {e}")
  237. time.sleep(0.5)
  238. def _reconnect_rtsp(self):
  239. """重建RTSP连接"""
  240. rtsp_url = self._build_rtsp_url()
  241. if self.rtsp_cap is not None:
  242. try:
  243. self.rtsp_cap.release()
  244. except Exception:
  245. pass
  246. self.rtsp_cap = None
  247. time.sleep(1)
  248. try:
  249. self.rtsp_cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
  250. if safe_is_opened(self.rtsp_cap, self._camera_id):
  251. self.rtsp_cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
  252. print("全景RTSP流重连成功")
  253. else:
  254. print("全景RTSP流重连失败")
  255. self.rtsp_cap = None
  256. except Exception as e:
  257. print(f"全景RTSP流重连异常: {e}")
  258. self.rtsp_cap = None
  259. def stop_stream(self):
  260. """停止视频流"""
  261. self.running = False
  262. if self.stream_thread:
  263. self.stream_thread.join(timeout=2)
  264. if self.play_handle:
  265. self.sdk.stop_real_play(self.play_handle)
  266. self.play_handle = None
  267. if self.rtsp_cap:
  268. self.rtsp_cap.release()
  269. self.rtsp_cap = None
  270. def get_frame(self) -> Optional[np.ndarray]:
  271. """
  272. 获取当前帧
  273. Returns:
  274. 当前帧图像
  275. """
  276. with self.frame_lock:
  277. return self.current_frame.copy() if self.current_frame is not None else None
  278. def get_frame_from_queue(self, timeout: float = 0.1) -> Optional[np.ndarray]:
  279. """
  280. 从帧队列获取帧 (用于批量处理)
  281. Args:
  282. timeout: 等待超时时间
  283. Returns:
  284. 帧图像或None
  285. """
  286. try:
  287. return self.frame_queue.get(timeout=timeout)
  288. except:
  289. return None
  290. def get_frame_buffer(self, count: int = 5) -> List[np.ndarray]:
  291. """
  292. 获取帧缓冲 (用于运动检测等需要多帧的场景)
  293. Args:
  294. count: 获取帧数
  295. Returns:
  296. 帧列表
  297. """
  298. frames = []
  299. while len(frames) < count:
  300. frame = self.get_frame_from_queue(timeout=0.05)
  301. if frame is not None:
  302. frames.append(frame)
  303. else:
  304. break
  305. return frames
  306. def set_detector(self, detector):
  307. """设置物体检测器"""
  308. self.detector = detector
  309. def detect_objects(self, frame: np.ndarray = None) -> List[DetectedObject]:
  310. """
  311. 检测物体
  312. Args:
  313. frame: 输入帧,如果为None则使用当前帧
  314. Returns:
  315. 检测到的物体列表
  316. """
  317. if frame is None:
  318. frame = self.get_frame()
  319. if frame is None or self.detector is None:
  320. return []
  321. return self.detector.detect(frame)
  322. def get_detection_position(self, obj: DetectedObject,
  323. frame_size: Tuple[int, int]) -> Tuple[float, float]:
  324. """
  325. 获取检测物体在画面中的相对位置
  326. Args:
  327. obj: 检测到的物体
  328. frame_size: 画面尺寸 (width, height)
  329. Returns:
  330. 相对位置 (x_ratio, y_ratio) 范围0-1
  331. """
  332. width, height = frame_size
  333. x_ratio = obj.center[0] / width
  334. y_ratio = obj.center[1] / height
  335. return (x_ratio, y_ratio)
  336. class ObjectDetector:
  337. """
  338. 物体检测器
  339. 使用YOLO11模型进行人体检测
  340. 支持 YOLO (.pt), RKNN (.rknn), ONNX (.onnx) 模型
  341. """
  342. def __init__(self, model_path: str = None, use_gpu: bool = True, model_size: str = 'n',
  343. model_type: str = 'auto'):
  344. """
  345. 初始化检测器
  346. Args:
  347. model_path: 模型路径 (支持 .pt, .rknn, .onnx)
  348. use_gpu: 是否使用GPU
  349. model_size: 模型尺寸 ('n', 's', 'm', 'l', 'x') - 仅 YOLO 模型有效
  350. model_type: 模型类型 ('auto', 'yolo', 'rknn', 'onnx')
  351. """
  352. self.model = None
  353. self.rknn_detector = None
  354. self.model_path = model_path
  355. self.use_gpu = use_gpu
  356. self.model_size = model_size
  357. self.model_type = model_type
  358. self.config = DETECTION_CONFIG
  359. self.device = 'cuda:0' if use_gpu else 'cpu'
  360. # 检测图片保存配置
  361. self._save_image_enabled = self.config.get('save_detection_image', False)
  362. self._image_save_dir = Path(self.config.get('detection_image_dir', './detection_images'))
  363. self._image_max_count = self.config.get('detection_image_max_count', 1000)
  364. self._last_save_time = 0
  365. self._save_interval = 1.0 # 最小保存间隔(秒),避免保存过于频繁
  366. # 创建保存目录
  367. if self._save_image_enabled:
  368. self._ensure_save_dir()
  369. # 根据扩展名自动判断模型类型
  370. if model_path:
  371. ext = os.path.splitext(model_path)[1].lower()
  372. if ext == '.rknn':
  373. self.model_type = 'rknn'
  374. elif ext == '.onnx':
  375. self.model_type = 'onnx'
  376. elif ext == '.pt':
  377. self.model_type = 'yolo'
  378. self._load_model()
  379. def _load_model(self):
  380. """加载检测模型"""
  381. if self.model_type == 'rknn':
  382. self._load_rknn_model()
  383. elif self.model_type == 'onnx':
  384. self._load_onnx_model()
  385. else:
  386. self._load_yolo_model()
  387. def _load_rknn_model(self):
  388. """加载 RKNN 模型"""
  389. if not self.model_path:
  390. raise ValueError("RKNN 模型需要指定 model_path")
  391. try:
  392. from rknnlite.api import RKNNLite
  393. self.rknn = RKNNLite()
  394. ret = self.rknn.load_rknn(self.model_path)
  395. if ret != 0:
  396. raise RuntimeError(f"加载 RKNN 模型失败: {self.model_path}")
  397. ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  398. if ret != 0:
  399. raise RuntimeError(f"初始化 RKNN 运行时失败")
  400. print(f"RKNN 模型加载成功: {self.model_path}")
  401. except ImportError:
  402. raise ImportError("未安装 rknnlite,请运行: pip install rknnlite2")
  403. def _load_onnx_model(self):
  404. """加载 ONNX 模型"""
  405. if not self.model_path:
  406. raise ValueError("ONNX 模型需要指定 model_path")
  407. try:
  408. import onnxruntime as ort
  409. self.session = ort.InferenceSession(self.model_path)
  410. self.input_name = self.session.get_inputs()[0].name
  411. self.output_name = self.session.get_outputs()[0].name
  412. print(f"ONNX 模型加载成功: {self.model_path}")
  413. except ImportError:
  414. raise ImportError("未安装 onnxruntime,请运行: pip install onnxruntime")
  415. def _load_yolo_model(self):
  416. """加载YOLO11检测模型"""
  417. try:
  418. from ultralytics import YOLO
  419. if self.model_path:
  420. self.model = YOLO(self.model_path)
  421. else:
  422. model_name = f'yolo11{self.model_size}.pt'
  423. self.model = YOLO(model_name)
  424. dummy = np.zeros((640, 640, 3), dtype=np.uint8)
  425. self.model(dummy, device=self.device, verbose=False)
  426. print(f"成功加载YOLO11检测模型 (device={self.device})")
  427. except ImportError:
  428. print("未安装ultralytics,请运行: pip install ultralytics")
  429. self._load_opencv_model()
  430. except Exception as e:
  431. print(f"加载YOLO11模型失败: {e}")
  432. self._load_opencv_model()
  433. def _load_opencv_model(self):
  434. """使用OpenCV加载模型"""
  435. pass
  436. def _ensure_save_dir(self):
  437. """确保保存目录存在"""
  438. try:
  439. self._image_save_dir.mkdir(parents=True, exist_ok=True)
  440. logger.info(f"检测图片保存目录: {self._image_save_dir}")
  441. except Exception as e:
  442. logger.error(f"创建检测图片目录失败: {e}")
  443. self._save_image_enabled = False
  444. def _cleanup_old_images(self):
  445. """清理旧图片,保持目录下图片数量不超过上限"""
  446. try:
  447. image_files = list(self._image_save_dir.glob("*.jpg"))
  448. if len(image_files) > self._image_max_count:
  449. # 按修改时间排序,删除最旧的
  450. image_files.sort(key=lambda x: x.stat().st_mtime)
  451. to_delete = image_files[:len(image_files) - self._image_max_count]
  452. for f in to_delete:
  453. f.unlink()
  454. logger.info(f"已清理 {len(to_delete)} 张旧检测图片")
  455. except Exception as e:
  456. logger.error(f"清理旧图片失败: {e}")
  457. def _save_detection_image(self, frame: np.ndarray, detections: List[DetectedObject]):
  458. """
  459. 保存带有检测标记的图片(只标记达到置信度阈值的人)
  460. Args:
  461. frame: 原始图像
  462. detections: 检测结果列表
  463. """
  464. if not self._save_image_enabled or not detections:
  465. return
  466. # 检查保存间隔
  467. current_time = time.time()
  468. if current_time - self._last_save_time < self._save_interval:
  469. return
  470. try:
  471. # 复制图像避免修改原图
  472. marked_frame = frame.copy()
  473. # 置信度阈值(人员检测用更高阈值)
  474. person_threshold = self.config.get('person_threshold', 0.8)
  475. # 只标记达到阈值的人
  476. person_count = 0
  477. for det in detections:
  478. # 只处理人且达到阈值
  479. is_person = det.class_name in ['person']
  480. if not is_person:
  481. continue
  482. # 未达阈值的不标记
  483. if det.confidence < person_threshold:
  484. continue
  485. x, y, w, h = det.bbox
  486. # 绘制边界框(绿色)
  487. cv2.rectangle(marked_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
  488. # 绘制序号标签
  489. label = f"person_{person_count}"
  490. person_count += 1
  491. (label_w, label_h), baseline = cv2.getTextSize(
  492. label, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2
  493. )
  494. cv2.rectangle(
  495. marked_frame,
  496. (x, y - label_h - 8),
  497. (x + label_w, y),
  498. (0, 255, 0),
  499. -1
  500. )
  501. # 绘制标签文字(黑色)
  502. cv2.putText(
  503. marked_frame, label,
  504. (x, y - 4),
  505. cv2.FONT_HERSHEY_SIMPLEX, 0.8,
  506. (0, 0, 0), 2
  507. )
  508. # 无有效目标则不保存
  509. if person_count == 0:
  510. return
  511. # 生成文件名(时间戳+有效人数)
  512. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
  513. filename = f"panorama_{timestamp}_n{person_count}.jpg"
  514. filepath = self._image_save_dir / filename
  515. # 保存图片
  516. cv2.imwrite(str(filepath), marked_frame, [cv2.IMWRITE_JPEG_QUALITY, 90])
  517. self._last_save_time = current_time
  518. logger.info(f"[全景] 已保存检测图片: {filepath},有效人数 {person_count} (阈值={person_threshold})")
  519. # 定期清理旧图片
  520. self._cleanup_old_images()
  521. except Exception as e:
  522. logger.error(f"[全景] 保存检测图片失败: {e}")
  523. def _letterbox(self, image, size=(640, 640)):
  524. """Letterbox 预处理"""
  525. h0, w0 = image.shape[:2]
  526. ih, iw = size
  527. scale = min(iw / w0, ih / h0)
  528. new_w, new_h = int(w0 * scale), int(h0 * scale)
  529. pad_w = (iw - new_w) // 2
  530. pad_h = (ih - new_h) // 2
  531. resized = cv2.resize(image, (new_w, new_h))
  532. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  533. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  534. return canvas, scale, pad_w, pad_h, h0, w0
  535. def _detect_rknn(self, frame: np.ndarray) -> List[DetectedObject]:
  536. """使用 RKNN/ONNX 模型检测"""
  537. results = []
  538. try:
  539. canvas, scale, pad_w, pad_h, h0, w0 = self._letterbox(frame)
  540. if hasattr(self, 'rknn'):
  541. # RKNN
  542. img = canvas[..., ::-1].astype(np.float32) / 255.0
  543. blob = img[None, ...]
  544. outputs = self.rknn.inference(inputs=[blob])
  545. else:
  546. # ONNX
  547. img = canvas[..., ::-1].astype(np.float32) / 255.0
  548. img = img.transpose(2, 0, 1)
  549. blob = img[None, ...]
  550. outputs = self.session.run([self.output_name], {self.input_name: blob})
  551. output = outputs[0]
  552. if len(output.shape) == 3:
  553. output = output[0]
  554. num_boxes = output.shape[1]
  555. conf_threshold = self.config['confidence_threshold']
  556. for i in range(num_boxes):
  557. x_center = float(output[0, i])
  558. y_center = float(output[1, i])
  559. width = float(output[2, i])
  560. height = float(output[3, i])
  561. class_probs = output[4:, i]
  562. best_class = int(np.argmax(class_probs))
  563. confidence = float(class_probs[best_class])
  564. if confidence < conf_threshold:
  565. continue
  566. # 转换到原始图像坐标
  567. x1 = int(((x_center - width / 2) - pad_w) / scale)
  568. y1 = int(((y_center - height / 2) - pad_h) / scale)
  569. x2 = int(((x_center + width / 2) - pad_w) / scale)
  570. y2 = int(((y_center + height / 2) - pad_h) / scale)
  571. x1 = max(0, min(w0, x1))
  572. y1 = max(0, min(h0, y1))
  573. x2 = max(0, min(w0, x2))
  574. y2 = max(0, min(h0, y2))
  575. if x2 - x1 < 10 or y2 - y1 < 10:
  576. continue
  577. # 使用配置的类别映射获取类别名称
  578. class_map = self.config.get('class_map', {0: 'person', 3: '人'})
  579. cls_name = class_map.get(best_class, str(best_class))
  580. # 检查是否为目标类别
  581. if cls_name not in self.config['target_classes']:
  582. continue
  583. obj = DetectedObject(
  584. class_name=cls_name,
  585. confidence=confidence,
  586. bbox=(x1, y1, x2 - x1, y2 - y1),
  587. center=((x1 + x2) // 2, (y1 + y2) // 2)
  588. )
  589. results.append(obj)
  590. except Exception as e:
  591. logger.error(f"RKNN/ONNX 检测错误: {e}")
  592. return results
  593. def detect(self, frame: np.ndarray) -> List[DetectedObject]:
  594. """检测物体返回所有类别结果"""
  595. if frame is None:
  596. return []
  597. if hasattr(self, 'rknn') and self.rknn is not None:
  598. results = self._detect_rknn(frame)
  599. if results:
  600. self._log_detections("RKNN", results, frame)
  601. self._save_detection_image(frame, results)
  602. return results
  603. elif hasattr(self, 'session') and self.session is not None:
  604. results = self._detect_rknn(frame)
  605. if results:
  606. self._log_detections("ONNX", results, frame)
  607. self._save_detection_image(frame, results)
  608. return results
  609. elif self.model is not None:
  610. results = self._detect_yolo(frame)
  611. if results:
  612. self._log_detections("YOLO", results, frame)
  613. self._save_detection_image(frame, results)
  614. return results
  615. else:
  616. logger.error("[YOLO] 没有可用的检测模型")
  617. return []
  618. def _log_detections(self, model_type: str, results: List[DetectedObject], frame: np.ndarray):
  619. if not results:
  620. return
  621. class_counts = {}
  622. for r in results:
  623. class_counts[r.class_name] = class_counts.get(r.class_name, 0) + 1
  624. h, w = frame.shape[:2]
  625. logger.info(f"[YOLO] {model_type}: {len(results)}个目标 {class_counts} (帧尺寸={w}x{h})")
  626. def _detect_yolo(self, frame: np.ndarray) -> List[DetectedObject]:
  627. """使用 YOLO 模型检测"""
  628. results = []
  629. try:
  630. detections = self.model(
  631. frame,
  632. device=self.device,
  633. verbose=False,
  634. conf=self.config['confidence_threshold']
  635. )
  636. for det in detections:
  637. boxes = det.boxes
  638. if boxes is None:
  639. continue
  640. for i in range(len(boxes)):
  641. cls_id = int(boxes.cls[i])
  642. cls_name = det.names[cls_id]
  643. if cls_name not in self.config['target_classes']:
  644. continue
  645. conf = float(boxes.conf[i])
  646. xyxy = boxes.xyxy[i].cpu().numpy()
  647. x1, y1, x2, y2 = map(int, xyxy)
  648. width = x2 - x1
  649. height = y2 - y1
  650. if width < 10 or height < 10:
  651. continue
  652. center_x = x1 + width // 2
  653. center_y = y1 + height // 2
  654. obj = DetectedObject(
  655. class_name=cls_name,
  656. confidence=conf,
  657. bbox=(x1, y1, width, height),
  658. center=(center_x, center_y)
  659. )
  660. results.append(obj)
  661. except Exception as e:
  662. logger.error(f"YOLO11检测错误: {e}")
  663. return results
  664. def detect_with_keypoints(self, frame: np.ndarray) -> List[DetectedObject]:
  665. """
  666. 使用YOLO11-pose检测人体并返回关键点
  667. Args:
  668. frame: 输入图像
  669. Returns:
  670. 带关键点的检测结果列表
  671. """
  672. return self.detect(frame)
  673. def detect_persons(self, frame: np.ndarray) -> List[DetectedObject]:
  674. """检测人体(支持中英文类别名)"""
  675. all_detections = self.detect(frame)
  676. person_classes = {'person', '人'}
  677. return [obj for obj in all_detections if obj.class_name in person_classes]
  678. def release(self):
  679. """释放模型资源"""
  680. if hasattr(self, 'rknn') and self.rknn:
  681. self.rknn.release()
  682. self.rknn = None
  683. self.model = None
  684. self.session = None
  685. class PersonTracker:
  686. """
  687. 人体跟踪器
  688. 使用简单的质心跟踪算法
  689. """
  690. def __init__(self, max_disappeared: int = 30):
  691. """
  692. 初始化跟踪器
  693. Args:
  694. max_disappeared: 最大消失帧数
  695. """
  696. self.max_disappeared = max_disappeared
  697. self.next_id = 0
  698. self.objects = {} # id -> center
  699. self.disappeared = {} # id -> disappeared count
  700. def update(self, detections: List[DetectedObject]) -> List[DetectedObject]:
  701. """
  702. 更新跟踪状态
  703. Args:
  704. detections: 当前帧检测结果
  705. Returns:
  706. 带有跟踪ID的检测结果
  707. """
  708. # 如果没有检测结果
  709. if len(detections) == 0:
  710. # 标记所有已跟踪对象为消失
  711. for obj_id in list(self.disappeared.keys()):
  712. self.disappeared[obj_id] += 1
  713. if self.disappeared[obj_id] > self.max_disappeared:
  714. self._deregister(obj_id)
  715. return []
  716. # 计算当前检测中心点
  717. input_centers = np.array([d.center for d in detections])
  718. # 如果没有已跟踪对象
  719. if len(self.objects) == 0:
  720. for det in detections:
  721. self._register(det)
  722. else:
  723. # 计算距离矩阵
  724. object_ids = list(self.objects.keys())
  725. object_centers = np.array([self.objects[obj_id] for obj_id in object_ids])
  726. # 计算欧氏距离
  727. distances = np.linalg.norm(
  728. object_centers[:, np.newaxis] - input_centers,
  729. axis=2
  730. )
  731. # 匈牙利算法匹配 (简化版: 贪心匹配)
  732. rows = distances.min(axis=1).argsort()
  733. cols = distances.argmin(axis=1)[rows]
  734. used_rows = set()
  735. used_cols = set()
  736. for (row, col) in zip(rows, cols):
  737. if row in used_rows or col in used_cols:
  738. continue
  739. obj_id = object_ids[row]
  740. self.objects[obj_id] = input_centers[col]
  741. self.disappeared[obj_id] = 0
  742. detections[col].track_id = obj_id
  743. used_rows.add(row)
  744. used_cols.add(col)
  745. # 处理未匹配的已跟踪对象
  746. unused_rows = set(range(len(object_ids))) - used_rows
  747. for row in unused_rows:
  748. obj_id = object_ids[row]
  749. self.disappeared[obj_id] += 1
  750. if self.disappeared[obj_id] > self.max_disappeared:
  751. self._deregister(obj_id)
  752. # 处理未匹配的新检测
  753. unused_cols = set(range(len(input_centers))) - used_cols
  754. for col in unused_cols:
  755. self._register(detections[col])
  756. return [d for d in detections if d.track_id is not None]
  757. def _register(self, detection: DetectedObject):
  758. """注册新对象"""
  759. detection.track_id = self.next_id
  760. self.objects[self.next_id] = detection.center
  761. self.disappeared[self.next_id] = 0
  762. self.next_id += 1
  763. def _deregister(self, obj_id: int):
  764. """注销对象"""
  765. del self.objects[obj_id]
  766. del self.disappeared[obj_id]