Explorar el Código

feat: add UltralyticsTracker with YOLO/RKNN support

wenhongquan hace 1 día
padre
commit
ab19e8d1d8
Se han modificado 2 ficheros con 315 adiciones y 0 borrados
  1. 55 0
      dual_camera_system/tests/test_tracker.py
  2. 260 0
      dual_camera_system/tracker.py

+ 55 - 0
dual_camera_system/tests/test_tracker.py

@@ -0,0 +1,55 @@
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import pytest
+from tracker import UltralyticsTracker, TrackedPerson
+
+
+def test_tracked_person_dataclass():
+    p = TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9)
+    assert p.track_id == 1
+    assert p.class_name == "person"
+
+
+def test_tracker_yolo_path_returns_person(monkeypatch):
+    """YOLO 路径:本地若无 GPU/模型,mock model 行为验证过滤逻辑"""
+    tracker = UltralyticsTracker.__new__(UltralyticsTracker)
+    tracker.model_type = "yolo"
+    tracker.conf_threshold = 0.5
+    tracker.person_threshold = 0.5
+    tracker.use_gpu = False
+    tracker.tracker_type = "bytetrack"
+
+    class FakeBox:
+        cls = np.array([0.0])
+        conf = np.array([0.8])
+        xyxy = np.array([[10, 20, 30, 40]])
+        id = None
+
+        def __len__(self):
+            return len(self.cls)
+
+    class FakeResult:
+        names = {0: "person"}
+        boxes = FakeBox()
+
+    def fake_model(frame, **kwargs):
+        return [FakeResult()]
+
+    tracker.model = fake_model
+
+    frame = np.zeros((480, 640, 3), dtype=np.uint8)
+    results = tracker._detect_yolo(frame)
+    assert len(results) == 1
+    # 未经过 tracker 关联时,track_id 使用占位值 -1
+    assert results[0].track_id == -1
+
+
+def test_resolve_model_fallback():
+    from tracker import resolve_model
+    # 不存在的路径应回退到 yolo11n.pt
+    path, mtype = resolve_model("/not/exist/model.rknn", "auto")
+    assert path == "yolo11n.pt"
+    assert mtype == "yolo"

+ 260 - 0
dual_camera_system/tracker.py

@@ -0,0 +1,260 @@
+"""
+Ultralytics Tracker 封装
+支持 YOLO (.pt) 端到端跟踪 和 RKNN/ONNX 检测 + BYTETracker 关联
+"""
+
+import os
+from typing import List, Tuple, Optional
+from dataclasses import dataclass
+
+import numpy as np
+
+from config import TRACKING_CONFIG
+
+
+@dataclass
+class TrackedPerson:
+    """跟踪目标"""
+    track_id: int
+    bbox: Tuple[int, int, int, int]   # x1, y1, x2, y2
+    center: Tuple[int, int]
+    confidence: float
+    class_name: str = "person"
+    lost: bool = False
+
+
+def resolve_model(model_path: Optional[str], model_type: str) -> Tuple[str, str]:
+    """
+    解析模型路径和类型
+    优先级:model_path > TRACKING_CONFIG['fallback_model_path'] > yolo11n.pt 自动下载
+    """
+    if model_path and os.path.exists(model_path):
+        ext = os.path.splitext(model_path)[1].lower()
+        if ext == ".rknn":
+            return model_path, "rknn"
+        elif ext == ".onnx":
+            return model_path, "onnx"
+        elif ext == ".pt":
+            return model_path, "yolo"
+
+    # 尝试 fallback 路径
+    fallback = TRACKING_CONFIG.get("fallback_model_path")
+    if fallback and os.path.exists(fallback):
+        ext = os.path.splitext(fallback)[1].lower()
+        if ext == ".rknn":
+            return fallback, "rknn"
+        elif ext == ".onnx":
+            return fallback, "onnx"
+        return fallback, "yolo"
+
+    # 最终回退:Ultralytics 自动下载
+    return "yolo11n.pt", "yolo"
+
+
+class UltralyticsTracker:
+    """Ultralytics 跟踪器封装"""
+
+    def __init__(
+        self,
+        model_path: Optional[str] = None,
+        model_type: str = "auto",
+        use_gpu: bool = True,
+        tracker_type: str = "bytetrack",
+        conf_threshold: float = 0.5,
+        person_threshold: float = 0.5,
+        max_lost: int = 30,
+    ):
+        if model_path is None:
+            model_path = TRACKING_CONFIG["model_path"]
+
+        self.model_path = model_path
+        self.model_type = model_type
+        self.use_gpu = use_gpu
+        self.tracker_type = tracker_type
+        self.conf_threshold = conf_threshold
+        self.person_threshold = person_threshold
+        self.max_lost = max_lost
+
+        self.model = None
+        self.rknn_detector = None
+        self.byte_tracker = None
+
+        resolved_path, resolved_type = resolve_model(model_path, model_type)
+        self.model_path = resolved_path
+        self.model_type = resolved_type
+
+        self._load_model()
+
+    def _load_model(self):
+        if self.model_type == "rknn":
+            self._load_rknn_model()
+        elif self.model_type == "onnx":
+            self._load_onnx_model()
+        else:
+            self._load_yolo_model()
+
+    def _load_yolo_model(self):
+        from ultralytics import YOLO
+        self.model = YOLO(self.model_path)
+        dummy = np.zeros((640, 640, 3), dtype=np.uint8)
+        device = "cuda:0" if self.use_gpu else "cpu"
+        self.model(dummy, task="track", tracker=f"{self.tracker_type}.yaml", persist=True, verbose=False, device=device)
+        print(f"YOLO 跟踪模型加载成功: {self.model_path}")
+
+    def _load_rknn_model(self):
+        from safety_detector import RKNNDetector
+        self.rknn_detector = RKNNDetector(self.model_path)
+        self._init_byte_tracker()
+        print(f"RKNN 跟踪模型加载成功: {self.model_path}")
+
+    def _load_onnx_model(self):
+        from safety_detector import ONNXDetector
+        self.rknn_detector = ONNXDetector(self.model_path)
+        self._init_byte_tracker()
+        print(f"ONNX 跟踪模型加载成功: {self.model_path}")
+
+    def _init_byte_tracker(self):
+        try:
+            from ultralytics.trackers.byte_tracker import BYTETracker
+            self.byte_tracker = BYTETracker(args=self._tracker_args())
+        except Exception as e:
+            print(f"初始化 BYTETracker 失败: {e},将使用简化 IOU 关联")
+            self.byte_tracker = None
+
+    def _tracker_args(self):
+        class Args:
+            track_thresh = self.conf_threshold
+            match_thresh = 0.8
+            track_buffer = self.max_lost
+            mot20 = False
+        return Args()
+
+    def update(self, frame: np.ndarray) -> List[TrackedPerson]:
+        if frame is None:
+            return []
+        if self.model_type == "yolo":
+            return self._update_yolo(frame)
+        else:
+            return self._update_rknn_onnx(frame)
+
+    def _update_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
+        device = "cuda:0" if self.use_gpu else "cpu"
+        results = self.model(
+            frame,
+            task="track",
+            tracker=f"{self.tracker_type}.yaml",
+            persist=True,
+            conf=self.conf_threshold,
+            verbose=False,
+            device=device,
+        )
+        return self._parse_yolo_results(results, frame.shape)
+
+    def _detect_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
+        """仅供测试/mock 使用的 YOLO 检测入口,返回解析后的跟踪目标。"""
+        device = "cuda:0" if self.use_gpu else "cpu"
+        results = self.model(
+            frame,
+            task="track",
+            tracker=f"{self.tracker_type}.yaml",
+            persist=True,
+            conf=self.conf_threshold,
+            verbose=False,
+            device=device,
+        )
+        return self._parse_yolo_results(results, frame.shape)
+
+    def _parse_yolo_results(self, results, frame_shape) -> List[TrackedPerson]:
+        persons = []
+        h, w = frame_shape[:2]
+        for det in results:
+            boxes = det.boxes
+            if boxes is None or len(boxes) == 0:
+                continue
+            for i in range(len(boxes)):
+                cls_id = int(boxes.cls[i])
+                cls_name = det.names.get(cls_id, str(cls_id))
+                if cls_name != "person":
+                    continue
+                conf = float(boxes.conf[i])
+                if conf < self.person_threshold:
+                    continue
+                xyxy = boxes.xyxy[i]
+                if hasattr(xyxy, "cpu"):
+                    xyxy = xyxy.cpu().numpy()
+                x1, y1, x2, y2 = map(int, xyxy)
+                track_id = int(boxes.id[i]) if boxes.id is not None else -1
+                center_x = (x1 + x2) // 2
+                center_y = (y1 + y2) // 2
+                persons.append(TrackedPerson(
+                    track_id=track_id,
+                    bbox=(x1, y1, x2, y2),
+                    center=(center_x, center_y),
+                    confidence=conf,
+                ))
+        return persons
+
+    def _update_rknn_onnx(self, frame: np.ndarray) -> List[TrackedPerson]:
+        from safety_detector import Detection
+        conf_map = {3: self.person_threshold}
+        detections = self.rknn_detector.detect(frame, conf_map)
+        # 只保留 person
+        person_dets = [d for d in detections if d.class_id == 3]
+        if not person_dets:
+            return []
+
+        if self.byte_tracker is None:
+            return self._simple_association(person_dets)
+
+        # 构造 BYTETracker 输入 [x1, y1, x2, y2, conf, cls]
+        try:
+            import torch
+            dets = []
+            for d in person_dets:
+                x1, y1, x2, y2 = d.bbox
+                dets.append([x1, y1, x2, y2, d.confidence, d.class_id])
+            dets_t = torch.tensor(dets, dtype=torch.float32)
+            tracks = self.byte_tracker.update(dets_t, frame.shape)
+            persons = []
+            for t in tracks:
+                x1, y1, x2, y2 = map(int, t.tlbr)
+                center_x = (x1 + x2) // 2
+                center_y = (y1 + y2) // 2
+                persons.append(TrackedPerson(
+                    track_id=int(t.track_id),
+                    bbox=(x1, y1, x2, y2),
+                    center=(center_x, center_y),
+                    confidence=float(t.score),
+                ))
+            return persons
+        except Exception as e:
+            print(f"BYTETracker 更新失败: {e},使用简化关联")
+            return self._simple_association(person_dets)
+
+    def _simple_association(self, detections: List) -> List[TrackedPerson]:
+        """简化关联:无 ID 复用,每次返回新 track_id"""
+        persons = []
+        for d in detections:
+            x1, y1, x2, y2 = d.bbox
+            center_x = (x1 + x2) // 2
+            center_y = (y1 + y2) // 2
+            persons.append(TrackedPerson(
+                track_id=-1,
+                bbox=(x1, y1, x2, y2),
+                center=(center_x, center_y),
+                confidence=d.confidence,
+            ))
+        return persons
+
+    def reset(self):
+        if self.model_type == "yolo" and self.model is not None:
+            self.model.predictor.trackers = []
+        if self.byte_tracker is not None:
+            self._init_byte_tracker()
+
+    def release(self):
+        if self.rknn_detector is not None:
+            self.rknn_detector.release()
+            self.rknn_detector = None
+        self.model = None
+        self.byte_tracker = None