Просмотр исходного кода

refactor(tracker): handle import fallbacks, respect model_type, expand tests

wenhongquan 1 день назад
Родитель
Сommit
398078d787
2 измененных файлов с 255 добавлено и 97 удалено
  1. 155 31
      dual_camera_system/tests/test_tracker.py
  2. 100 66
      dual_camera_system/tracker.py

+ 155 - 31
dual_camera_system/tests/test_tracker.py

@@ -2,9 +2,50 @@ import sys
 import os
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
+import logging
+import types
+
 import numpy as np
 import pytest
-from tracker import UltralyticsTracker, TrackedPerson
+
+import ultralytics
+import tracker
+from tracker import UltralyticsTracker, TrackedPerson, resolve_model
+
+
+class FakeBox:
+    def __init__(self, cls, conf, xyxy, ids=None):
+        self.cls = np.array(cls)
+        self.conf = np.array(conf)
+        self.xyxy = np.array(xyxy)
+        self.id = np.array(ids) if ids is not None else None
+
+    def __len__(self):
+        return len(self.cls)
+
+
+class FakeResult:
+    def __init__(self, names, boxes):
+        self.names = names
+        self.boxes = boxes
+
+
+@pytest.fixture
+def fake_yolo(monkeypatch):
+    """返回一个函数,用于将 ultralytics.YOLO 替换为返回指定结果的假实现。"""
+    def _make(results):
+        class FakeYOLO:
+            def __init__(self, path):
+                self.path = path
+                self.predictor = types.SimpleNamespace(trackers=[])
+                self._results = results
+
+            def __call__(self, frame, **kwargs):
+                return self._results
+
+        monkeypatch.setattr(ultralytics, 'YOLO', FakeYOLO)
+
+    return _make
 
 
 def test_tracked_person_dataclass():
@@ -13,43 +54,126 @@ def test_tracked_person_dataclass():
     assert p.class_name == "person"
 
 
-def test_tracker_yolo_path_returns_person(monkeypatch):
-    """YOLO 路径:本地若无 GPU/模型,mock model 行为验证过滤逻辑"""
-    tracker = UltralyticsTracker.__new__(UltralyticsTracker)
-    tracker.model_type = "yolo"
-    tracker.conf_threshold = 0.5
-    tracker.person_threshold = 0.5
-    tracker.use_gpu = False
-    tracker.tracker_type = "bytetrack"
+def test_resolve_model_fallback():
+    # 不存在的路径应回退到默认 YOLO 模型
+    path, mtype = resolve_model("/not/exist/model.rknn", "auto")
+    assert path == "yolo11n.pt"
+    assert mtype == "yolo"
+
 
-    class FakeBox:
-        cls = np.array([0.0])
-        conf = np.array([0.8])
-        xyxy = np.array([[10, 20, 30, 40]])
-        id = None
+def test_resolve_model_auto_by_extension(tmp_path):
+    # 创建空文件,仅用于路径/扩展名推断
+    pt_path = tmp_path / "model.pt"
+    rknn_path = tmp_path / "model.rknn"
+    onnx_path = tmp_path / "model.onnx"
+    pt_path.write_text("")
+    rknn_path.write_text("")
+    onnx_path.write_text("")
 
-        def __len__(self):
-            return len(self.cls)
+    assert resolve_model(str(pt_path), "auto") == (str(pt_path), "yolo")
+    assert resolve_model(str(rknn_path), "auto") == (str(rknn_path), "rknn")
+    assert resolve_model(str(onnx_path), "auto") == (str(onnx_path), "onnx")
 
-    class FakeResult:
-        names = {0: "person"}
-        boxes = FakeBox()
 
-    def fake_model(frame, **kwargs):
-        return [FakeResult()]
+def test_resolve_model_respects_explicit_type(tmp_path):
+    # 显式 model_type 优先于扩展名推断
+    path = tmp_path / "weird.bin"
+    path.write_text("")
+    assert resolve_model(str(path), "rknn") == (str(path), "rknn")
+    assert resolve_model(str(path), "onnx") == (str(path), "onnx")
+    assert resolve_model(str(path), "yolo") == (str(path), "yolo")
 
-    tracker.model = fake_model
 
+def test_update_filters_non_person(fake_yolo):
+    fake_yolo([
+        FakeResult(names={0: "car"}, boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]))
+    ])
+    tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
     frame = np.zeros((480, 640, 3), dtype=np.uint8)
-    results = tracker._detect_yolo(frame)
+    results = tracker.update(frame)
+    assert results == []
+
+
+def test_update_invalid_frame_returns_empty(fake_yolo):
+    fake_yolo([
+        FakeResult(names={0: "person"}, boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]))
+    ])
+    tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
+    assert tracker.update(None) == []
+
+
+def test_tracker_lifecycle(fake_yolo):
+    fake_yolo([
+        FakeResult(
+            names={0: "person"},
+            boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]], ids=[42]),
+        )
+    ])
+    tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo", use_gpu=False)
+    frame = np.zeros((480, 640, 3), dtype=np.uint8)
+
+    results = tracker.update(frame)
     assert len(results) == 1
-    # 未经过 tracker 关联时,track_id 使用占位值 -1
-    assert results[0].track_id == -1
+    assert results[0].track_id == 42
+    assert results[0].bbox == (10, 20, 30, 40)
 
+    tracker.reset()
+    assert tracker.model.predictor.trackers == []
 
-def test_resolve_model_fallback():
-    from tracker import resolve_model
-    # 不存在的路径应回退到 yolo11n.pt
-    path, mtype = resolve_model("/not/exist/model.rknn", "auto")
-    assert path == "yolo11n.pt"
-    assert mtype == "yolo"
+    tracker.release()
+    assert tracker.model is None
+    assert tracker.byte_tracker is None
+
+
+def test_rknn_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
+    rknn_path = tmp_path / "model.rknn"
+    rknn_path.write_text("")
+
+    class FailingRKNN:
+        def __init__(self, path):
+            raise ImportError("rknnlite not installed")
+
+    monkeypatch.setattr("safety_detector.RKNNDetector", FailingRKNN)
+    fake_yolo([
+        FakeResult(
+            names={0: "person"},
+            boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]]),
+        )
+    ])
+
+    caplog.set_level(logging.WARNING, logger="tracker")
+    tracker = UltralyticsTracker(model_path=str(rknn_path), model_type="rknn", use_gpu=False)
+
+    assert tracker.model_type == "yolo"
+    assert "RKNN 加载失败" in caplog.text
+
+    frame = np.zeros((480, 640, 3), dtype=np.uint8)
+    results = tracker.update(frame)
+    assert len(results) == 1
+
+
+def test_onnx_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
+    onnx_path = tmp_path / "model.onnx"
+    onnx_path.write_text("")
+
+    class FailingONNX:
+        def __init__(self, path):
+            raise ImportError("onnxruntime not installed")
+
+    monkeypatch.setattr("safety_detector.ONNXDetector", FailingONNX)
+    fake_yolo([
+        FakeResult(
+            names={0: "person"},
+            boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]]),
+        )
+    ])
+
+    caplog.set_level(logging.WARNING, logger="tracker")
+    tracker = UltralyticsTracker(model_path=str(onnx_path), model_type="onnx", use_gpu=False)
+
+    assert tracker.model_type == "yolo"
+    assert "ONNX 加载失败" in caplog.text
+
+    frame = np.zeros((480, 640, 3), dtype=np.uint8)
+    results = tracker.update(frame)
+    assert len(results) == 1

+ 100 - 66
dual_camera_system/tracker.py

@@ -3,13 +3,30 @@ Ultralytics Tracker 封装
 支持 YOLO (.pt) 端到端跟踪 和 RKNN/ONNX 检测 + BYTETracker 关联
 """
 
+import logging
 import os
-from typing import List, Tuple, Optional
+import types
+from typing import Any, List, Tuple, Optional
 from dataclasses import dataclass
 
 import numpy as np
 
 from config import TRACKING_CONFIG
+from safety_detector import Detection
+
+
+logger = logging.getLogger(__name__)
+
+
+# Model type constants
+MODEL_TYPE_AUTO = "auto"
+MODEL_TYPE_RKNN = "rknn"
+MODEL_TYPE_ONNX = "onnx"
+MODEL_TYPE_YOLO = "yolo"
+
+# Default YOLO model used when no local model is found.
+# Ultralytics will automatically download the weights on first use.
+DEFAULT_YOLO_MODEL = "yolo11n.pt"
 
 
 @dataclass
@@ -26,38 +43,58 @@ class TrackedPerson:
 def resolve_model(model_path: Optional[str], model_type: str) -> Tuple[str, str]:
     """
     解析模型路径和类型
-    优先级:model_path > TRACKING_CONFIG['fallback_model_path'] > yolo11n.pt 自动下载
+
+    优先级:
+    1. 显式 model_type(非 auto)优先于扩展名推断
+    2. model_path 存在时使用 model_path
+    3. 否则使用 TRACKING_CONFIG['fallback_model_path']
+    4. 最终回退到 Ultralytics 默认模型(自动下载)
+
+    Args:
+        model_path: 模型文件路径,可为 None
+        model_type: 模型类型,'auto' 时根据扩展名推断,否则使用给定值
+
+    Returns:
+        (resolved_path, resolved_type)
     """
-    if model_path and os.path.exists(model_path):
-        ext = os.path.splitext(model_path)[1].lower()
+
+    def _infer_type(path: str) -> str:
+        ext = os.path.splitext(path)[1].lower()
         if ext == ".rknn":
-            return model_path, "rknn"
+            return MODEL_TYPE_RKNN
         elif ext == ".onnx":
-            return model_path, "onnx"
-        elif ext == ".pt":
-            return model_path, "yolo"
+            return MODEL_TYPE_ONNX
+        return MODEL_TYPE_YOLO
+
+    # 1. 优先使用传入的 model_path
+    if model_path and os.path.exists(model_path):
+        resolved_type = _infer_type(model_path) if model_type == MODEL_TYPE_AUTO else model_type
+        return model_path, resolved_type
 
-    # 尝试 fallback 路径
+    # 2. 回退到配置中的 fallback 路径
     fallback = TRACKING_CONFIG.get("fallback_model_path")
     if fallback and os.path.exists(fallback):
-        ext = os.path.splitext(fallback)[1].lower()
-        if ext == ".rknn":
-            return fallback, "rknn"
-        elif ext == ".onnx":
-            return fallback, "onnx"
-        return fallback, "yolo"
+        resolved_type = _infer_type(fallback) if model_type == MODEL_TYPE_AUTO else model_type
+        return fallback, resolved_type
 
-    # 最终回退:Ultralytics 自动下载
-    return "yolo11n.pt", "yolo"
+    # 3. 最终回退:Ultralytics 自动下载
+    return DEFAULT_YOLO_MODEL, MODEL_TYPE_YOLO
 
 
 class UltralyticsTracker:
-    """Ultralytics 跟踪器封装"""
+    """Ultralytics 跟踪器封装
+
+    阈值说明:
+    - conf_threshold: 调用模型/跟踪器时传入的检测置信度阈值,用于控制进入
+      跟踪流程的候选框数量。
+    - person_threshold: 对检测到的 "person" 类别在解析结果时应用的过滤阈值,
+      仅保留置信度不低于该值的人员目标。
+    """
 
     def __init__(
         self,
         model_path: Optional[str] = None,
-        model_type: str = "auto",
+        model_type: str = MODEL_TYPE_AUTO,
         use_gpu: bool = True,
         tracker_type: str = "bytetrack",
         conf_threshold: float = 0.5,
@@ -85,54 +122,66 @@ class UltralyticsTracker:
 
         self._load_model()
 
-    def _load_model(self):
-        if self.model_type == "rknn":
+    def _load_model(self) -> None:
+        if self.model_type == MODEL_TYPE_RKNN:
             self._load_rknn_model()
-        elif self.model_type == "onnx":
+        elif self.model_type == MODEL_TYPE_ONNX:
             self._load_onnx_model()
         else:
             self._load_yolo_model()
 
-    def _load_yolo_model(self):
+    def _load_yolo_model(self) -> None:
         from ultralytics import YOLO
         self.model = YOLO(self.model_path)
         dummy = np.zeros((640, 640, 3), dtype=np.uint8)
         device = "cuda:0" if self.use_gpu else "cpu"
+        # Warmup / JIT:在空白图上执行一次跟踪,触发 ultralytics 内部
+        # 的 tracker 初始化与可能的 PyTorch JIT 编译,避免首帧真实推理延迟。
         self.model(dummy, task="track", tracker=f"{self.tracker_type}.yaml", persist=True, verbose=False, device=device)
-        print(f"YOLO 跟踪模型加载成功: {self.model_path}")
+        logger.info("YOLO 跟踪模型加载成功: %s", self.model_path)
 
-    def _load_rknn_model(self):
-        from safety_detector import RKNNDetector
-        self.rknn_detector = RKNNDetector(self.model_path)
-        self._init_byte_tracker()
-        print(f"RKNN 跟踪模型加载成功: {self.model_path}")
+    def _load_rknn_model(self) -> None:
+        try:
+            from safety_detector import RKNNDetector
+            self.rknn_detector = RKNNDetector(self.model_path)
+            self._init_byte_tracker()
+            logger.info("RKNN 跟踪模型加载成功: %s", self.model_path)
+        except ImportError as e:
+            logger.warning("RKNN 加载失败 (%s),回退到 YOLO 模型", e)
+            self.model_type = MODEL_TYPE_YOLO
+            self._load_yolo_model()
 
-    def _load_onnx_model(self):
-        from safety_detector import ONNXDetector
-        self.rknn_detector = ONNXDetector(self.model_path)
-        self._init_byte_tracker()
-        print(f"ONNX 跟踪模型加载成功: {self.model_path}")
+    def _load_onnx_model(self) -> None:
+        try:
+            from safety_detector import ONNXDetector
+            self.rknn_detector = ONNXDetector(self.model_path)
+            self._init_byte_tracker()
+            logger.info("ONNX 跟踪模型加载成功: %s", self.model_path)
+        except ImportError as e:
+            logger.warning("ONNX 加载失败 (%s),回退到 YOLO 模型", e)
+            self.model_type = MODEL_TYPE_YOLO
+            self._load_yolo_model()
 
-    def _init_byte_tracker(self):
+    def _init_byte_tracker(self) -> None:
         try:
             from ultralytics.trackers.byte_tracker import BYTETracker
             self.byte_tracker = BYTETracker(args=self._tracker_args())
         except Exception as e:
-            print(f"初始化 BYTETracker 失败: {e},将使用简化 IOU 关联")
+            logger.warning("初始化 BYTETracker 失败: %s,将使用简化 IOU 关联", e)
             self.byte_tracker = None
 
-    def _tracker_args(self):
-        class Args:
-            track_thresh = self.conf_threshold
-            match_thresh = 0.8
-            track_buffer = self.max_lost
-            mot20 = False
-        return Args()
+    def _tracker_args(self) -> types.SimpleNamespace:
+        return types.SimpleNamespace(
+            track_thresh=self.conf_threshold,
+            match_thresh=0.8,
+            track_buffer=self.max_lost,
+            mot20=False,
+        )
 
-    def update(self, frame: np.ndarray) -> List[TrackedPerson]:
+    def update(self, frame: Optional[np.ndarray]) -> List[TrackedPerson]:
         if frame is None:
             return []
-        if self.model_type == "yolo":
+        if self.model_type == MODEL_TYPE_YOLO:
             return self._update_yolo(frame)
         else:
             return self._update_rknn_onnx(frame)
@@ -150,21 +199,7 @@ class UltralyticsTracker:
         )
         return self._parse_yolo_results(results, frame.shape)
 
-    def _detect_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
-        """仅供测试/mock 使用的 YOLO 检测入口,返回解析后的跟踪目标。"""
-        device = "cuda:0" if self.use_gpu else "cpu"
-        results = self.model(
-            frame,
-            task="track",
-            tracker=f"{self.tracker_type}.yaml",
-            persist=True,
-            conf=self.conf_threshold,
-            verbose=False,
-            device=device,
-        )
-        return self._parse_yolo_results(results, frame.shape)
-
-    def _parse_yolo_results(self, results, frame_shape) -> List[TrackedPerson]:
+    def _parse_yolo_results(self, results: List[Any], frame_shape: Tuple[int, ...]) -> List[TrackedPerson]:
         persons = []
         h, w = frame_shape[:2]
         for det in results:
@@ -195,7 +230,6 @@ class UltralyticsTracker:
         return persons
 
     def _update_rknn_onnx(self, frame: np.ndarray) -> List[TrackedPerson]:
-        from safety_detector import Detection
         conf_map = {3: self.person_threshold}
         detections = self.rknn_detector.detect(frame, conf_map)
         # 只保留 person
@@ -228,10 +262,10 @@ class UltralyticsTracker:
                 ))
             return persons
         except Exception as e:
-            print(f"BYTETracker 更新失败: {e},使用简化关联")
+            logger.warning("BYTETracker 更新失败: %s,使用简化关联", e)
             return self._simple_association(person_dets)
 
-    def _simple_association(self, detections: List) -> List[TrackedPerson]:
+    def _simple_association(self, detections: List[Detection]) -> List[TrackedPerson]:
         """简化关联:无 ID 复用,每次返回新 track_id"""
         persons = []
         for d in detections:
@@ -246,13 +280,13 @@ class UltralyticsTracker:
             ))
         return persons
 
-    def reset(self):
-        if self.model_type == "yolo" and self.model is not None:
+    def reset(self) -> None:
+        if self.model_type == MODEL_TYPE_YOLO and self.model is not None:
             self.model.predictor.trackers = []
         if self.byte_tracker is not None:
             self._init_byte_tracker()
 
-    def release(self):
+    def release(self) -> None:
         if self.rknn_detector is not None:
             self.rknn_detector.release()
             self.rknn_detector = None