panorama_camera.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. """
  2. 全景摄像头模块
  3. 负责获取视频流和物体检测
  4. """
  5. import cv2
  6. import numpy as np
  7. import threading
  8. import queue
  9. import time
  10. from typing import Optional, List, Tuple, Dict, Any
  11. from dataclasses import dataclass
  12. from config import PANORAMA_CAMERA, DETECTION_CONFIG
  13. from dahua_sdk import DahuaSDK, PTZCommand
  14. @dataclass
  15. class DetectedObject:
  16. """检测到的物体"""
  17. class_name: str # 类别名称
  18. confidence: float # 置信度
  19. bbox: Tuple[int, int, int, int] # 边界框 (x, y, width, height)
  20. center: Tuple[int, int] # 中心点坐标
  21. track_id: Optional[int] = None # 跟踪ID
  22. class PanoramaCamera:
  23. """全景摄像头类"""
  24. def __init__(self, sdk: DahuaSDK, camera_config: Dict = None):
  25. """
  26. 初始化全景摄像头
  27. Args:
  28. sdk: 大华SDK实例
  29. camera_config: 摄像头配置
  30. """
  31. self.sdk = sdk
  32. self.config = camera_config or PANORAMA_CAMERA
  33. self.login_handle = None
  34. self.play_handle = None
  35. self.connected = False
  36. # 视频流
  37. self.frame_queue = queue.Queue(maxsize=10)
  38. self.current_frame = None
  39. self.frame_lock = threading.Lock()
  40. self.rtsp_cap = None # RTSP视频捕获
  41. # 检测器
  42. self.detector = None
  43. # 控制标志
  44. self.running = False
  45. self.stream_thread = None
  46. # 断线重连
  47. self.auto_reconnect = True
  48. self.reconnect_interval = 5.0 # 重连间隔(秒)
  49. self.max_reconnect_attempts = 3 # 最大重连次数
  50. def connect(self) -> bool:
  51. """
  52. 连接摄像头
  53. Returns:
  54. 是否成功
  55. """
  56. login_handle, error = self.sdk.login(
  57. self.config['ip'],
  58. self.config['port'],
  59. self.config['username'],
  60. self.config['password']
  61. )
  62. if login_handle is None:
  63. print(f"连接全景摄像头失败: IP={self.config['ip']}, 错误码={error}")
  64. return False
  65. self.login_handle = login_handle
  66. self.connected = True
  67. print(f"成功连接全景摄像头: {self.config['ip']}")
  68. return True
  69. def disconnect(self):
  70. """断开连接"""
  71. self.stop_stream()
  72. if self.login_handle:
  73. self.sdk.logout(self.login_handle)
  74. self.login_handle = None
  75. self.connected = False
  76. def start_stream(self) -> bool:
  77. """
  78. 开始视频流
  79. Returns:
  80. 是否成功
  81. """
  82. if not self.connected:
  83. return False
  84. self.play_handle = self.sdk.real_play(
  85. self.login_handle,
  86. self.config['channel']
  87. )
  88. if self.play_handle is None:
  89. print("启动视频流失败")
  90. return False
  91. self.running = True
  92. self.stream_thread = threading.Thread(target=self._stream_worker, daemon=True)
  93. self.stream_thread.start()
  94. print("视频流已启动")
  95. return True
  96. def start_stream_rtsp(self, rtsp_url: str = None) -> bool:
  97. """
  98. 通过RTSP协议获取视频流
  99. Args:
  100. rtsp_url: RTSP地址,格式: rtsp://user:pass@ip:port/channel
  101. Returns:
  102. 是否成功
  103. """
  104. if rtsp_url is None:
  105. # 构建RTSP地址
  106. rtsp_url = f"rtsp://{self.config['username']}:{self.config['password']}@{self.config['ip']}:554/h264/ch{self.config['channel']}/main/av_stream"
  107. try:
  108. self.rtsp_cap = cv2.VideoCapture(rtsp_url)
  109. if not self.rtsp_cap.isOpened():
  110. print(f"无法打开RTSP流: {rtsp_url}")
  111. return False
  112. self.running = True
  113. self.stream_thread = threading.Thread(target=self._rtsp_stream_worker, daemon=True)
  114. self.stream_thread.start()
  115. print(f"RTSP视频流已启动: {rtsp_url}")
  116. return True
  117. except Exception as e:
  118. print(f"RTSP流启动失败: {e}")
  119. return False
  120. def _stream_worker(self):
  121. """视频流工作线程 (SDK模式)"""
  122. retry_count = 0
  123. max_retries = 10
  124. while self.running:
  125. try:
  126. # 尝试从 SDK 帧缓冲区获取帧 (如果可用)
  127. frame_buffer = self.sdk.get_video_frame_buffer(self.config['channel'])
  128. if frame_buffer:
  129. frame_info = frame_buffer.get(timeout=0.1)
  130. if frame_info and frame_info.get('data'):
  131. # 解码帧数据 (如果需要)
  132. # 注意: SDK回调返回的是编码数据,需要解码
  133. # 这里暂时跳过,因为解码需要额外处理
  134. pass
  135. # RTSP 模式获取帧 (推荐方式)
  136. if self.rtsp_cap is not None and self.rtsp_cap.isOpened():
  137. ret, frame = self.rtsp_cap.read()
  138. if ret and frame is not None:
  139. with self.frame_lock:
  140. self.current_frame = frame.copy()
  141. try:
  142. self.frame_queue.put(frame.copy(), block=False)
  143. except queue.Full:
  144. pass
  145. retry_count = 0 # 重置重试计数
  146. time.sleep(0.001) # 减少CPU占用
  147. continue
  148. # 如果 RTSP 不可用,尝试自动连接
  149. if retry_count < max_retries:
  150. rtsp_url = self._build_rtsp_url()
  151. try:
  152. if self.rtsp_cap is None:
  153. self.rtsp_cap = cv2.VideoCapture(rtsp_url)
  154. self.rtsp_cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) # 减少缓冲延迟
  155. if self.rtsp_cap.isOpened():
  156. retry_count = 0
  157. continue
  158. except Exception as e:
  159. pass
  160. retry_count += 1
  161. time.sleep(1.0) # 重试间隔
  162. else:
  163. # 超过最大重试次数,使用模拟帧
  164. frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
  165. with self.frame_lock:
  166. self.current_frame = frame
  167. try:
  168. self.frame_queue.put(frame, block=False)
  169. except queue.Full:
  170. pass
  171. time.sleep(0.1)
  172. except Exception as e:
  173. print(f"视频流错误: {e}")
  174. time.sleep(0.1)
  175. def _build_rtsp_url(self) -> str:
  176. """构建 RTSP URL"""
  177. return f"rtsp://{self.config['username']}:{self.config['password']}@{self.config['ip']}:554/h264/ch{self.config['channel']}/main/av_stream"
  178. def _rtsp_stream_worker(self):
  179. """RTSP视频流工作线程"""
  180. while self.running:
  181. try:
  182. if self.rtsp_cap is None or not self.rtsp_cap.isOpened():
  183. time.sleep(0.1)
  184. continue
  185. ret, frame = self.rtsp_cap.read()
  186. if not ret or frame is None:
  187. time.sleep(0.01)
  188. continue
  189. with self.frame_lock:
  190. self.current_frame = frame.copy()
  191. try:
  192. self.frame_queue.put(frame, block=False)
  193. except queue.Full:
  194. pass
  195. except Exception as e:
  196. print(f"RTSP视频流错误: {e}")
  197. time.sleep(0.1)
  198. def stop_stream(self):
  199. """停止视频流"""
  200. self.running = False
  201. if self.stream_thread:
  202. self.stream_thread.join(timeout=2)
  203. if self.play_handle:
  204. self.sdk.stop_real_play(self.play_handle)
  205. self.play_handle = None
  206. if self.rtsp_cap:
  207. self.rtsp_cap.release()
  208. self.rtsp_cap = None
  209. def get_frame(self) -> Optional[np.ndarray]:
  210. """
  211. 获取当前帧
  212. Returns:
  213. 当前帧图像
  214. """
  215. with self.frame_lock:
  216. return self.current_frame.copy() if self.current_frame is not None else None
  217. def get_frame_from_queue(self, timeout: float = 0.1) -> Optional[np.ndarray]:
  218. """
  219. 从帧队列获取帧 (用于批量处理)
  220. Args:
  221. timeout: 等待超时时间
  222. Returns:
  223. 帧图像或None
  224. """
  225. try:
  226. return self.frame_queue.get(timeout=timeout)
  227. except:
  228. return None
  229. def get_frame_buffer(self, count: int = 5) -> List[np.ndarray]:
  230. """
  231. 获取帧缓冲 (用于运动检测等需要多帧的场景)
  232. Args:
  233. count: 获取帧数
  234. Returns:
  235. 帧列表
  236. """
  237. frames = []
  238. while len(frames) < count:
  239. frame = self.get_frame_from_queue(timeout=0.05)
  240. if frame is not None:
  241. frames.append(frame)
  242. else:
  243. break
  244. return frames
  245. def set_detector(self, detector):
  246. """设置物体检测器"""
  247. self.detector = detector
  248. def detect_objects(self, frame: np.ndarray = None) -> List[DetectedObject]:
  249. """
  250. 检测物体
  251. Args:
  252. frame: 输入帧,如果为None则使用当前帧
  253. Returns:
  254. 检测到的物体列表
  255. """
  256. if frame is None:
  257. frame = self.get_frame()
  258. if frame is None or self.detector is None:
  259. return []
  260. return self.detector.detect(frame)
  261. def get_detection_position(self, obj: DetectedObject,
  262. frame_size: Tuple[int, int]) -> Tuple[float, float]:
  263. """
  264. 获取检测物体在画面中的相对位置
  265. Args:
  266. obj: 检测到的物体
  267. frame_size: 画面尺寸 (width, height)
  268. Returns:
  269. 相对位置 (x_ratio, y_ratio) 范围0-1
  270. """
  271. width, height = frame_size
  272. x_ratio = obj.center[0] / width
  273. y_ratio = obj.center[1] / height
  274. return (x_ratio, y_ratio)
  275. class ObjectDetector:
  276. """
  277. 物体检测器
  278. 使用YOLO11模型进行人体检测
  279. """
  280. def __init__(self, model_path: str = None, use_gpu: bool = True, model_size: str = 'n'):
  281. """
  282. 初始化检测器
  283. Args:
  284. model_path: 模型路径 (自定义模型)
  285. use_gpu: 是否使用GPU
  286. model_size: 模型尺寸 ('n', 's', 'm', 'l', 'x')
  287. """
  288. self.model = None
  289. self.model_path = model_path
  290. self.use_gpu = use_gpu
  291. self.model_size = model_size
  292. self.config = DETECTION_CONFIG
  293. self.device = 'cuda:0' if use_gpu else 'cpu'
  294. self._load_model()
  295. def _load_model(self):
  296. """加载YOLO11检测模型"""
  297. try:
  298. # 使用ultralytics YOLO11
  299. from ultralytics import YOLO
  300. if self.model_path:
  301. # 使用自定义模型
  302. self.model = YOLO(self.model_path)
  303. else:
  304. # 使用YOLO11预训练模型
  305. # YOLO11模型命名: yolo11n.pt, yolo11s.pt, yolo11m.pt, yolo11l.pt, yolo11x.pt
  306. model_name = f'yolo11{self.model_size}.pt'
  307. self.model = YOLO(model_name)
  308. # 预热模型
  309. dummy = np.zeros((640, 640, 3), dtype=np.uint8)
  310. self.model(dummy, device=self.device, verbose=False)
  311. print(f"成功加载YOLO11检测模型 (device={self.device})")
  312. except ImportError:
  313. print("未安装ultralytics,请运行: pip install ultralytics")
  314. self._load_opencv_model()
  315. except Exception as e:
  316. print(f"加载YOLO11模型失败: {e}")
  317. self._load_opencv_model()
  318. def _load_opencv_model(self):
  319. """使用OpenCV加载模型"""
  320. # 可以加载ONNX模型
  321. pass
  322. def detect(self, frame: np.ndarray) -> List[DetectedObject]:
  323. """
  324. 使用YOLO11检测物体
  325. Args:
  326. frame: 输入图像
  327. Returns:
  328. 检测结果列表
  329. """
  330. if self.model is None or frame is None:
  331. return []
  332. results = []
  333. try:
  334. # YOLO11推理
  335. detections = self.model(
  336. frame,
  337. device=self.device,
  338. verbose=False,
  339. conf=self.config['confidence_threshold']
  340. )
  341. for det in detections:
  342. boxes = det.boxes
  343. if boxes is None:
  344. continue
  345. for i in range(len(boxes)):
  346. # 获取类别
  347. cls_id = int(boxes.cls[i])
  348. cls_name = det.names[cls_id]
  349. # 过滤目标类别
  350. if cls_name not in self.config['target_classes']:
  351. continue
  352. # 获取置信度
  353. conf = float(boxes.conf[i])
  354. # 获取边界框
  355. xyxy = boxes.xyxy[i].cpu().numpy()
  356. x1, y1, x2, y2 = map(int, xyxy)
  357. width = x2 - x1
  358. height = y2 - y1
  359. # 过滤过小的检测框
  360. if width < 10 or height < 10:
  361. continue
  362. # 计算中心点
  363. center_x = x1 + width // 2
  364. center_y = y1 + height // 2
  365. obj = DetectedObject(
  366. class_name=cls_name,
  367. confidence=conf,
  368. bbox=(x1, y1, width, height),
  369. center=(center_x, center_y)
  370. )
  371. results.append(obj)
  372. except Exception as e:
  373. print(f"YOLO11检测错误: {e}")
  374. return results
  375. def detect_with_keypoints(self, frame: np.ndarray) -> List[DetectedObject]:
  376. """
  377. 使用YOLO11-pose检测人体并返回关键点
  378. Args:
  379. frame: 输入图像
  380. Returns:
  381. 带关键点的检测结果列表
  382. """
  383. # 如果使用pose模型,可以获取人体关键点
  384. # 用于更精确的人体定位
  385. return self.detect(frame)
  386. def detect_persons(self, frame: np.ndarray) -> List[DetectedObject]:
  387. """
  388. 检测人体
  389. Args:
  390. frame: 输入图像
  391. Returns:
  392. 检测到的人体列表
  393. """
  394. results = self.detect(frame)
  395. return [obj for obj in results if obj.class_name == 'person']
  396. class PersonTracker:
  397. """
  398. 人体跟踪器
  399. 使用简单的质心跟踪算法
  400. """
  401. def __init__(self, max_disappeared: int = 30):
  402. """
  403. 初始化跟踪器
  404. Args:
  405. max_disappeared: 最大消失帧数
  406. """
  407. self.max_disappeared = max_disappeared
  408. self.next_id = 0
  409. self.objects = {} # id -> center
  410. self.disappeared = {} # id -> disappeared count
  411. def update(self, detections: List[DetectedObject]) -> List[DetectedObject]:
  412. """
  413. 更新跟踪状态
  414. Args:
  415. detections: 当前帧检测结果
  416. Returns:
  417. 带有跟踪ID的检测结果
  418. """
  419. # 如果没有检测结果
  420. if len(detections) == 0:
  421. # 标记所有已跟踪对象为消失
  422. for obj_id in list(self.disappeared.keys()):
  423. self.disappeared[obj_id] += 1
  424. if self.disappeared[obj_id] > self.max_disappeared:
  425. self._deregister(obj_id)
  426. return []
  427. # 计算当前检测中心点
  428. input_centers = np.array([d.center for d in detections])
  429. # 如果没有已跟踪对象
  430. if len(self.objects) == 0:
  431. for det in detections:
  432. self._register(det)
  433. else:
  434. # 计算距离矩阵
  435. object_ids = list(self.objects.keys())
  436. object_centers = np.array([self.objects[obj_id] for obj_id in object_ids])
  437. # 计算欧氏距离
  438. distances = np.linalg.norm(
  439. object_centers[:, np.newaxis] - input_centers,
  440. axis=2
  441. )
  442. # 匈牙利算法匹配 (简化版: 贪心匹配)
  443. rows = distances.min(axis=1).argsort()
  444. cols = distances.argmin(axis=1)[rows]
  445. used_rows = set()
  446. used_cols = set()
  447. for (row, col) in zip(rows, cols):
  448. if row in used_rows or col in used_cols:
  449. continue
  450. obj_id = object_ids[row]
  451. self.objects[obj_id] = input_centers[col]
  452. self.disappeared[obj_id] = 0
  453. detections[col].track_id = obj_id
  454. used_rows.add(row)
  455. used_cols.add(col)
  456. # 处理未匹配的已跟踪对象
  457. unused_rows = set(range(len(object_ids))) - used_rows
  458. for row in unused_rows:
  459. obj_id = object_ids[row]
  460. self.disappeared[obj_id] += 1
  461. if self.disappeared[obj_id] > self.max_disappeared:
  462. self._deregister(obj_id)
  463. # 处理未匹配的新检测
  464. unused_cols = set(range(len(input_centers))) - used_cols
  465. for col in unused_cols:
  466. self._register(detections[col])
  467. return [d for d in detections if d.track_id is not None]
  468. def _register(self, detection: DetectedObject):
  469. """注册新对象"""
  470. detection.track_id = self.next_id
  471. self.objects[self.next_id] = detection.center
  472. self.disappeared[self.next_id] = 0
  473. self.next_id += 1
  474. def _deregister(self, obj_id: int):
  475. """注销对象"""
  476. del self.objects[obj_id]
  477. del self.disappeared[obj_id]