|
@@ -0,0 +1,260 @@
|
|
|
|
|
+"""
|
|
|
|
|
+Ultralytics Tracker 封装
|
|
|
|
|
+支持 YOLO (.pt) 端到端跟踪 和 RKNN/ONNX 检测 + BYTETracker 关联
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import os
|
|
|
|
|
+from typing import List, Tuple, Optional
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
|
|
+
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+from config import TRACKING_CONFIG
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@dataclass
|
|
|
|
|
+class TrackedPerson:
|
|
|
|
|
+ """跟踪目标"""
|
|
|
|
|
+ track_id: int
|
|
|
|
|
+ bbox: Tuple[int, int, int, int] # x1, y1, x2, y2
|
|
|
|
|
+ center: Tuple[int, int]
|
|
|
|
|
+ confidence: float
|
|
|
|
|
+ class_name: str = "person"
|
|
|
|
|
+ lost: bool = False
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def resolve_model(model_path: Optional[str], model_type: str) -> Tuple[str, str]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 解析模型路径和类型
|
|
|
|
|
+ 优先级:model_path > TRACKING_CONFIG['fallback_model_path'] > yolo11n.pt 自动下载
|
|
|
|
|
+ """
|
|
|
|
|
+ if model_path and os.path.exists(model_path):
|
|
|
|
|
+ ext = os.path.splitext(model_path)[1].lower()
|
|
|
|
|
+ if ext == ".rknn":
|
|
|
|
|
+ return model_path, "rknn"
|
|
|
|
|
+ elif ext == ".onnx":
|
|
|
|
|
+ return model_path, "onnx"
|
|
|
|
|
+ elif ext == ".pt":
|
|
|
|
|
+ return model_path, "yolo"
|
|
|
|
|
+
|
|
|
|
|
+ # 尝试 fallback 路径
|
|
|
|
|
+ fallback = TRACKING_CONFIG.get("fallback_model_path")
|
|
|
|
|
+ if fallback and os.path.exists(fallback):
|
|
|
|
|
+ ext = os.path.splitext(fallback)[1].lower()
|
|
|
|
|
+ if ext == ".rknn":
|
|
|
|
|
+ return fallback, "rknn"
|
|
|
|
|
+ elif ext == ".onnx":
|
|
|
|
|
+ return fallback, "onnx"
|
|
|
|
|
+ return fallback, "yolo"
|
|
|
|
|
+
|
|
|
|
|
+ # 最终回退:Ultralytics 自动下载
|
|
|
|
|
+ return "yolo11n.pt", "yolo"
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class UltralyticsTracker:
|
|
|
|
|
+ """Ultralytics 跟踪器封装"""
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(
|
|
|
|
|
+ self,
|
|
|
|
|
+ model_path: Optional[str] = None,
|
|
|
|
|
+ model_type: str = "auto",
|
|
|
|
|
+ use_gpu: bool = True,
|
|
|
|
|
+ tracker_type: str = "bytetrack",
|
|
|
|
|
+ conf_threshold: float = 0.5,
|
|
|
|
|
+ person_threshold: float = 0.5,
|
|
|
|
|
+ max_lost: int = 30,
|
|
|
|
|
+ ):
|
|
|
|
|
+ if model_path is None:
|
|
|
|
|
+ model_path = TRACKING_CONFIG["model_path"]
|
|
|
|
|
+
|
|
|
|
|
+ self.model_path = model_path
|
|
|
|
|
+ self.model_type = model_type
|
|
|
|
|
+ self.use_gpu = use_gpu
|
|
|
|
|
+ self.tracker_type = tracker_type
|
|
|
|
|
+ self.conf_threshold = conf_threshold
|
|
|
|
|
+ self.person_threshold = person_threshold
|
|
|
|
|
+ self.max_lost = max_lost
|
|
|
|
|
+
|
|
|
|
|
+ self.model = None
|
|
|
|
|
+ self.rknn_detector = None
|
|
|
|
|
+ self.byte_tracker = None
|
|
|
|
|
+
|
|
|
|
|
+ resolved_path, resolved_type = resolve_model(model_path, model_type)
|
|
|
|
|
+ self.model_path = resolved_path
|
|
|
|
|
+ self.model_type = resolved_type
|
|
|
|
|
+
|
|
|
|
|
+ self._load_model()
|
|
|
|
|
+
|
|
|
|
|
+ def _load_model(self):
|
|
|
|
|
+ if self.model_type == "rknn":
|
|
|
|
|
+ self._load_rknn_model()
|
|
|
|
|
+ elif self.model_type == "onnx":
|
|
|
|
|
+ self._load_onnx_model()
|
|
|
|
|
+ else:
|
|
|
|
|
+ self._load_yolo_model()
|
|
|
|
|
+
|
|
|
|
|
+ def _load_yolo_model(self):
|
|
|
|
|
+ from ultralytics import YOLO
|
|
|
|
|
+ self.model = YOLO(self.model_path)
|
|
|
|
|
+ dummy = np.zeros((640, 640, 3), dtype=np.uint8)
|
|
|
|
|
+ device = "cuda:0" if self.use_gpu else "cpu"
|
|
|
|
|
+ self.model(dummy, task="track", tracker=f"{self.tracker_type}.yaml", persist=True, verbose=False, device=device)
|
|
|
|
|
+ print(f"YOLO 跟踪模型加载成功: {self.model_path}")
|
|
|
|
|
+
|
|
|
|
|
+ def _load_rknn_model(self):
|
|
|
|
|
+ from safety_detector import RKNNDetector
|
|
|
|
|
+ self.rknn_detector = RKNNDetector(self.model_path)
|
|
|
|
|
+ self._init_byte_tracker()
|
|
|
|
|
+ print(f"RKNN 跟踪模型加载成功: {self.model_path}")
|
|
|
|
|
+
|
|
|
|
|
+ def _load_onnx_model(self):
|
|
|
|
|
+ from safety_detector import ONNXDetector
|
|
|
|
|
+ self.rknn_detector = ONNXDetector(self.model_path)
|
|
|
|
|
+ self._init_byte_tracker()
|
|
|
|
|
+ print(f"ONNX 跟踪模型加载成功: {self.model_path}")
|
|
|
|
|
+
|
|
|
|
|
+ def _init_byte_tracker(self):
|
|
|
|
|
+ try:
|
|
|
|
|
+ from ultralytics.trackers.byte_tracker import BYTETracker
|
|
|
|
|
+ self.byte_tracker = BYTETracker(args=self._tracker_args())
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"初始化 BYTETracker 失败: {e},将使用简化 IOU 关联")
|
|
|
|
|
+ self.byte_tracker = None
|
|
|
|
|
+
|
|
|
|
|
+ def _tracker_args(self):
|
|
|
|
|
+ class Args:
|
|
|
|
|
+ track_thresh = self.conf_threshold
|
|
|
|
|
+ match_thresh = 0.8
|
|
|
|
|
+ track_buffer = self.max_lost
|
|
|
|
|
+ mot20 = False
|
|
|
|
|
+ return Args()
|
|
|
|
|
+
|
|
|
|
|
+ def update(self, frame: np.ndarray) -> List[TrackedPerson]:
|
|
|
|
|
+ if frame is None:
|
|
|
|
|
+ return []
|
|
|
|
|
+ if self.model_type == "yolo":
|
|
|
|
|
+ return self._update_yolo(frame)
|
|
|
|
|
+ else:
|
|
|
|
|
+ return self._update_rknn_onnx(frame)
|
|
|
|
|
+
|
|
|
|
|
+ def _update_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
|
|
|
|
|
+ device = "cuda:0" if self.use_gpu else "cpu"
|
|
|
|
|
+ results = self.model(
|
|
|
|
|
+ frame,
|
|
|
|
|
+ task="track",
|
|
|
|
|
+ tracker=f"{self.tracker_type}.yaml",
|
|
|
|
|
+ persist=True,
|
|
|
|
|
+ conf=self.conf_threshold,
|
|
|
|
|
+ verbose=False,
|
|
|
|
|
+ device=device,
|
|
|
|
|
+ )
|
|
|
|
|
+ return self._parse_yolo_results(results, frame.shape)
|
|
|
|
|
+
|
|
|
|
|
+ def _detect_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
|
|
|
|
|
+ """仅供测试/mock 使用的 YOLO 检测入口,返回解析后的跟踪目标。"""
|
|
|
|
|
+ device = "cuda:0" if self.use_gpu else "cpu"
|
|
|
|
|
+ results = self.model(
|
|
|
|
|
+ frame,
|
|
|
|
|
+ task="track",
|
|
|
|
|
+ tracker=f"{self.tracker_type}.yaml",
|
|
|
|
|
+ persist=True,
|
|
|
|
|
+ conf=self.conf_threshold,
|
|
|
|
|
+ verbose=False,
|
|
|
|
|
+ device=device,
|
|
|
|
|
+ )
|
|
|
|
|
+ return self._parse_yolo_results(results, frame.shape)
|
|
|
|
|
+
|
|
|
|
|
+ def _parse_yolo_results(self, results, frame_shape) -> List[TrackedPerson]:
|
|
|
|
|
+ persons = []
|
|
|
|
|
+ h, w = frame_shape[:2]
|
|
|
|
|
+ for det in results:
|
|
|
|
|
+ boxes = det.boxes
|
|
|
|
|
+ if boxes is None or len(boxes) == 0:
|
|
|
|
|
+ continue
|
|
|
|
|
+ for i in range(len(boxes)):
|
|
|
|
|
+ cls_id = int(boxes.cls[i])
|
|
|
|
|
+ cls_name = det.names.get(cls_id, str(cls_id))
|
|
|
|
|
+ if cls_name != "person":
|
|
|
|
|
+ continue
|
|
|
|
|
+ conf = float(boxes.conf[i])
|
|
|
|
|
+ if conf < self.person_threshold:
|
|
|
|
|
+ continue
|
|
|
|
|
+ xyxy = boxes.xyxy[i]
|
|
|
|
|
+ if hasattr(xyxy, "cpu"):
|
|
|
|
|
+ xyxy = xyxy.cpu().numpy()
|
|
|
|
|
+ x1, y1, x2, y2 = map(int, xyxy)
|
|
|
|
|
+ track_id = int(boxes.id[i]) if boxes.id is not None else -1
|
|
|
|
|
+ center_x = (x1 + x2) // 2
|
|
|
|
|
+ center_y = (y1 + y2) // 2
|
|
|
|
|
+ persons.append(TrackedPerson(
|
|
|
|
|
+ track_id=track_id,
|
|
|
|
|
+ bbox=(x1, y1, x2, y2),
|
|
|
|
|
+ center=(center_x, center_y),
|
|
|
|
|
+ confidence=conf,
|
|
|
|
|
+ ))
|
|
|
|
|
+ return persons
|
|
|
|
|
+
|
|
|
|
|
+ def _update_rknn_onnx(self, frame: np.ndarray) -> List[TrackedPerson]:
|
|
|
|
|
+ from safety_detector import Detection
|
|
|
|
|
+ conf_map = {3: self.person_threshold}
|
|
|
|
|
+ detections = self.rknn_detector.detect(frame, conf_map)
|
|
|
|
|
+ # 只保留 person
|
|
|
|
|
+ person_dets = [d for d in detections if d.class_id == 3]
|
|
|
|
|
+ if not person_dets:
|
|
|
|
|
+ return []
|
|
|
|
|
+
|
|
|
|
|
+ if self.byte_tracker is None:
|
|
|
|
|
+ return self._simple_association(person_dets)
|
|
|
|
|
+
|
|
|
|
|
+ # 构造 BYTETracker 输入 [x1, y1, x2, y2, conf, cls]
|
|
|
|
|
+ try:
|
|
|
|
|
+ import torch
|
|
|
|
|
+ dets = []
|
|
|
|
|
+ for d in person_dets:
|
|
|
|
|
+ x1, y1, x2, y2 = d.bbox
|
|
|
|
|
+ dets.append([x1, y1, x2, y2, d.confidence, d.class_id])
|
|
|
|
|
+ dets_t = torch.tensor(dets, dtype=torch.float32)
|
|
|
|
|
+ tracks = self.byte_tracker.update(dets_t, frame.shape)
|
|
|
|
|
+ persons = []
|
|
|
|
|
+ for t in tracks:
|
|
|
|
|
+ x1, y1, x2, y2 = map(int, t.tlbr)
|
|
|
|
|
+ center_x = (x1 + x2) // 2
|
|
|
|
|
+ center_y = (y1 + y2) // 2
|
|
|
|
|
+ persons.append(TrackedPerson(
|
|
|
|
|
+ track_id=int(t.track_id),
|
|
|
|
|
+ bbox=(x1, y1, x2, y2),
|
|
|
|
|
+ center=(center_x, center_y),
|
|
|
|
|
+ confidence=float(t.score),
|
|
|
|
|
+ ))
|
|
|
|
|
+ return persons
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(f"BYTETracker 更新失败: {e},使用简化关联")
|
|
|
|
|
+ return self._simple_association(person_dets)
|
|
|
|
|
+
|
|
|
|
|
+ def _simple_association(self, detections: List) -> List[TrackedPerson]:
|
|
|
|
|
+ """简化关联:无 ID 复用,每次返回新 track_id"""
|
|
|
|
|
+ persons = []
|
|
|
|
|
+ for d in detections:
|
|
|
|
|
+ x1, y1, x2, y2 = d.bbox
|
|
|
|
|
+ center_x = (x1 + x2) // 2
|
|
|
|
|
+ center_y = (y1 + y2) // 2
|
|
|
|
|
+ persons.append(TrackedPerson(
|
|
|
|
|
+ track_id=-1,
|
|
|
|
|
+ bbox=(x1, y1, x2, y2),
|
|
|
|
|
+ center=(center_x, center_y),
|
|
|
|
|
+ confidence=d.confidence,
|
|
|
|
|
+ ))
|
|
|
|
|
+ return persons
|
|
|
|
|
+
|
|
|
|
|
+ def reset(self):
|
|
|
|
|
+ if self.model_type == "yolo" and self.model is not None:
|
|
|
|
|
+ self.model.predictor.trackers = []
|
|
|
|
|
+ if self.byte_tracker is not None:
|
|
|
|
|
+ self._init_byte_tracker()
|
|
|
|
|
+
|
|
|
|
|
+ def release(self):
|
|
|
|
|
+ if self.rknn_detector is not None:
|
|
|
|
|
+ self.rknn_detector.release()
|
|
|
|
|
+ self.rknn_detector = None
|
|
|
|
|
+ self.model = None
|
|
|
|
|
+ self.byte_tracker = None
|