|
|
@@ -7,6 +7,7 @@ import time
|
|
|
import threading
|
|
|
import queue
|
|
|
import logging
|
|
|
+import math
|
|
|
from typing import Optional, List, Dict, Tuple, Callable
|
|
|
from dataclasses import dataclass, field
|
|
|
from enum import Enum
|
|
|
@@ -393,8 +394,8 @@ class Coordinator:
|
|
|
def _coordinator_worker(self):
|
|
|
"""联动工作线程"""
|
|
|
last_detection_time = 0
|
|
|
- # 优先使用 detection_fps,默认每秒2帧
|
|
|
- detection_fps = self.config.get('detection_fps', 2)
|
|
|
+ # 从 DETECTION_CONFIG 获取检测帧率,默认每秒2帧
|
|
|
+ detection_fps = self.config.get('detection_fps', DETECTION_CONFIG.get('detection_fps', 2))
|
|
|
detection_interval = 1.0 / detection_fps # 根据FPS计算间隔
|
|
|
|
|
|
# 初始化统计
|
|
|
@@ -427,11 +428,7 @@ class Coordinator:
|
|
|
if detections:
|
|
|
self._update_stats('persons_detected', len(detections))
|
|
|
|
|
|
- # 为检测结果分配临时序号
|
|
|
- for idx, det in enumerate(detections):
|
|
|
- det.track_id = idx
|
|
|
-
|
|
|
- # 更新跟踪目标
|
|
|
+ # 更新跟踪目标(track_id 在此方法内分配)
|
|
|
self._update_tracking_targets(detections, frame_size)
|
|
|
|
|
|
# 处理检测结果
|
|
|
@@ -458,7 +455,10 @@ class Coordinator:
|
|
|
|
|
|
def _update_tracking_targets(self, detections: List[DetectedObject],
|
|
|
frame_size: Tuple[int, int]):
|
|
|
- """更新跟踪目标(仅添加有效人员)"""
|
|
|
+ """更新跟踪目标(跨帧匹配,支持粘性跟踪)
|
|
|
+
|
|
|
+ 改进:不再每轮清空目标,而是使用位置匹配关联连续帧的目标
|
|
|
+ """
|
|
|
current_time = time.time()
|
|
|
frame_w, frame_h = frame_size
|
|
|
center_x, center_y = frame_w / 2, frame_h / 2
|
|
|
@@ -466,36 +466,94 @@ class Coordinator:
|
|
|
# 获取人员置信度阈值
|
|
|
person_threshold = DETECTION_CONFIG.get('person_threshold', 0.8)
|
|
|
|
|
|
+ # 过滤有效人员
|
|
|
+ valid_detections = []
|
|
|
+ for det in detections:
|
|
|
+ if det.class_name != 'person':
|
|
|
+ continue
|
|
|
+ if det.confidence < person_threshold:
|
|
|
+ continue
|
|
|
+ valid_detections.append(det)
|
|
|
+
|
|
|
with self.targets_lock:
|
|
|
- # 清空上一轮目标(不再跟踪,每轮独立)
|
|
|
- self.tracking_targets.clear()
|
|
|
+ # 匹配阈值:位置距离小于此值认为是同一目标
|
|
|
+ MATCH_THRESHOLD = 0.15 # 画面比例
|
|
|
|
|
|
- # 只添加有效人员(class_name == 'person' 且置信度达标)
|
|
|
- for det in detections:
|
|
|
- if det.track_id is None:
|
|
|
- continue
|
|
|
+ # 已匹配的检测索引
|
|
|
+ matched_det_indices = set()
|
|
|
+
|
|
|
+ # 步骤1:尝试匹配现有目标
|
|
|
+ for track_id, target in list(self.tracking_targets.items()):
|
|
|
+ best_match_idx = None
|
|
|
+ best_match_dist = MATCH_THRESHOLD
|
|
|
|
|
|
- # 严格过滤:只处理人员且置信度达标
|
|
|
- if det.class_name != 'person':
|
|
|
- continue
|
|
|
- if det.confidence < person_threshold:
|
|
|
+ for idx, det in enumerate(valid_detections):
|
|
|
+ if idx in matched_det_indices:
|
|
|
+ continue
|
|
|
+
|
|
|
+ det_x = det.center[0] / frame_w
|
|
|
+ det_y = det.center[1] / frame_h
|
|
|
+
|
|
|
+ # 计算位置距离
|
|
|
+ dist = math.sqrt(
|
|
|
+ (det_x - target.position[0]) ** 2 +
|
|
|
+ (det_y - target.position[1]) ** 2
|
|
|
+ )
|
|
|
+
|
|
|
+ if dist < best_match_dist:
|
|
|
+ best_match_dist = dist
|
|
|
+ best_match_idx = idx
|
|
|
+
|
|
|
+
|
|
|
+ if best_match_idx is not None:
|
|
|
+ # 找到匹配,更新目标
|
|
|
+ det = valid_detections[best_match_idx]
|
|
|
+ matched_det_indices.add(best_match_idx)
|
|
|
+
|
|
|
+ x_ratio = det.center[0] / frame_w
|
|
|
+ y_ratio = det.center[1] / frame_h
|
|
|
+ _, _, width, height = det.bbox
|
|
|
+ area = width * height
|
|
|
+
|
|
|
+ dx = abs(det.center[0] - center_x) / center_x
|
|
|
+ dy = abs(det.center[1] - center_y) / center_y
|
|
|
+ center_distance = (dx + dy) / 2
|
|
|
+
|
|
|
+ # 更新目标属性
|
|
|
+ self.tracking_targets[track_id] = TrackingTarget(
|
|
|
+ track_id=track_id,
|
|
|
+ position=(x_ratio, y_ratio),
|
|
|
+ last_update=current_time,
|
|
|
+ area=area,
|
|
|
+ confidence=det.confidence,
|
|
|
+ center_distance=center_distance,
|
|
|
+ person_info=target.person_info # 保留之前识别的信息
|
|
|
+ )
|
|
|
+
|
|
|
+ # 步骤2:为未匹配的检测创建新目标
|
|
|
+ for idx, det in enumerate(valid_detections):
|
|
|
+ if idx in matched_det_indices:
|
|
|
continue
|
|
|
|
|
|
x_ratio = det.center[0] / frame_w
|
|
|
y_ratio = det.center[1] / frame_h
|
|
|
-
|
|
|
- # 计算面积
|
|
|
_, _, width, height = det.bbox
|
|
|
area = width * height
|
|
|
|
|
|
- # 计算到画面中心的距离比例
|
|
|
dx = abs(det.center[0] - center_x) / center_x
|
|
|
dy = abs(det.center[1] - center_y) / center_y
|
|
|
- center_distance = (dx + dy) / 2 # 归一化到0-1
|
|
|
+ center_distance = (dx + dy) / 2
|
|
|
+
|
|
|
+ # 分配全局唯一track_id
|
|
|
+ with self._track_id_lock:
|
|
|
+ new_track_id = self._next_track_id
|
|
|
+ self._next_track_id += 1
|
|
|
|
|
|
- # 添加为跟踪目标
|
|
|
- self.tracking_targets[det.track_id] = TrackingTarget(
|
|
|
- track_id=det.track_id,
|
|
|
+
|
|
|
+ det.track_id = new_track_id # 更新检测对象的track_id
|
|
|
+
|
|
|
+ self.tracking_targets[new_track_id] = TrackingTarget(
|
|
|
+ track_id=new_track_id,
|
|
|
position=(x_ratio, y_ratio),
|
|
|
last_update=current_time,
|
|
|
area=area,
|
|
|
@@ -779,8 +837,13 @@ class AsyncCoordinator(Coordinator):
|
|
|
# PTZ确认回调
|
|
|
self._on_ptz_confirmed: Optional[Callable] = None
|
|
|
|
|
|
- # 上次PTZ命令时间
|
|
|
+ # 上次PTZ命令时间(添加线程锁保护)
|
|
|
self._last_ptz_time = 0.0
|
|
|
+ self._last_ptz_time_lock = threading.Lock()
|
|
|
+
|
|
|
+ # 跨帧跟踪:全局track_id计数器
|
|
|
+ self._next_track_id = 1
|
|
|
+ self._track_id_lock = threading.Lock()
|
|
|
|
|
|
# 配对图片保存器
|
|
|
self._enable_paired_saving = DETECTION_CONFIG.get('enable_paired_saving', False)
|
|
|
@@ -877,8 +940,8 @@ class AsyncCoordinator(Coordinator):
|
|
|
def _detection_worker(self):
|
|
|
"""检测线程:持续读帧 + YOLO推理 + 发送PTZ命令 + 打印检测日志"""
|
|
|
last_detection_time = 0
|
|
|
- # 优先使用 detection_fps,默认每秒2帧
|
|
|
- detection_fps = self.config.get('detection_fps', 2)
|
|
|
+ # 从 DETECTION_CONFIG 获取检测帧率,默认每秒2帧
|
|
|
+ detection_fps = self.config.get('detection_fps', DETECTION_CONFIG.get('detection_fps', 2))
|
|
|
detection_interval = 1.0 / detection_fps # 根据FPS计算间隔
|
|
|
ptz_cooldown = self.config.get('ptz_command_cooldown', 0.5)
|
|
|
ptz_threshold = self.config.get('ptz_position_threshold', 0.03)
|
|
|
@@ -946,12 +1009,10 @@ class AsyncCoordinator(Coordinator):
|
|
|
self._update_stats('persons_detected', len(detections))
|
|
|
detection_person_count += 1
|
|
|
|
|
|
- # 为检测结果分配临时序号
|
|
|
- for idx, det in enumerate(detections):
|
|
|
- det.track_id = idx
|
|
|
+ # 更新跟踪目标(track_id 在此方法内分配)
|
|
|
self._update_tracking_targets(detections, frame_size)
|
|
|
|
|
|
- # 配对图片保存:创建新批次
|
|
|
+ # 配对图片保存:创建新批次(在 _update_tracking_targets 之后,使用正确的 track_id)
|
|
|
if detections and self._enable_paired_saving and self._paired_saver is not None:
|
|
|
self._create_detection_batch(frame, detections, frame_size)
|
|
|
|
|
|
@@ -1197,10 +1258,11 @@ class AsyncCoordinator(Coordinator):
|
|
|
abs(y_ratio - last_y) < self.ptz_position_threshold:
|
|
|
return
|
|
|
|
|
|
- # 冷却检查
|
|
|
+ # 冷却检查(线程安全)
|
|
|
current_time = time.time()
|
|
|
- if current_time - self._last_ptz_time < self.PTZ_COMMAND_COOLDOWN:
|
|
|
- return
|
|
|
+ with self._last_ptz_time_lock:
|
|
|
+ if current_time - self._last_ptz_time < self.PTZ_COMMAND_COOLDOWN:
|
|
|
+ return
|
|
|
|
|
|
cmd = PTZCommand(
|
|
|
pan=0, tilt=0, zoom=0,
|
|
|
@@ -1218,10 +1280,11 @@ class AsyncCoordinator(Coordinator):
|
|
|
"""发送PTZ命令并打印日志"""
|
|
|
x_ratio, y_ratio = target.position
|
|
|
|
|
|
- # 冷却检查(与 _send_ptz_command 保持一致)
|
|
|
+ # 冷却检查(线程安全)
|
|
|
current_time = time.time()
|
|
|
- if current_time - self._last_ptz_time < self.PTZ_COMMAND_COOLDOWN:
|
|
|
- return
|
|
|
+ with self._last_ptz_time_lock:
|
|
|
+ if current_time - self._last_ptz_time < self.PTZ_COMMAND_COOLDOWN:
|
|
|
+ return
|
|
|
|
|
|
# 位置变化阈值检查
|
|
|
if self.last_ptz_position is not None:
|
|
|
@@ -1271,7 +1334,9 @@ class AsyncCoordinator(Coordinator):
|
|
|
Args:
|
|
|
cmd: PTZ命令(包含 batch_id, person_index, track_id 用于配对保存)
|
|
|
"""
|
|
|
- self._last_ptz_time = time.time()
|
|
|
+ # 更新最后执行时间(线程安全)
|
|
|
+ with self._last_ptz_time_lock:
|
|
|
+ self._last_ptz_time = time.time()
|
|
|
|
|
|
# 从命令中提取配对保存相关信息
|
|
|
track_id = cmd.track_id
|