| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- import sys
- import os
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- import logging
- import types
- import numpy as np
- import pytest
- import ultralytics
- import tracker
- from tracker import UltralyticsTracker, TrackedPerson, resolve_model
class FakeBox:
    """Minimal stand-in for an Ultralytics-style boxes container used by the tests.

    Every field passed in is stored as a NumPy array. ``id`` stays ``None``
    when *ids* is omitted, so tests can model untracked detections.
    """

    def __init__(self, cls, conf, xyxy, ids=None):
        self.cls = np.array(cls)
        self.conf = np.array(conf)
        self.xyxy = np.array(xyxy)
        if ids is None:
            self.id = None
        else:
            self.id = np.array(ids)

    def __len__(self):
        """Return the number of boxes, measured on the class-id array."""
        return len(self.cls)
class FakeResult:
    """Minimal stand-in for one inference result: a class-id -> name map plus boxes."""

    def __init__(self, names, boxes):
        self.boxes = boxes
        self.names = names
@pytest.fixture
def fake_yolo(monkeypatch):
    """Provide a factory that replaces ``ultralytics.YOLO`` with a canned fake.

    Calling the returned factory with a list of results patches
    ``ultralytics.YOLO`` (reverted by ``monkeypatch`` at teardown). Instances
    of the fake return that same list from every call, ignoring the frame and
    any keyword arguments.
    """
    def _install(results):
        class FakeYOLO:
            def __init__(self, path):
                self.path = path
                # Attribute chain the lifecycle test inspects after reset().
                self.predictor = types.SimpleNamespace(trackers=[])
                self._results = results

            def __call__(self, frame, **kwargs):
                # Every inference call yields the canned results unchanged.
                return self._results

        monkeypatch.setattr(ultralytics, "YOLO", FakeYOLO)

    return _install
def test_tracked_person_dataclass():
    """TrackedPerson keeps the given track id and defaults class_name to "person"."""
    person = TrackedPerson(
        track_id=1,
        bbox=(10, 20, 30, 40),
        center=(20, 30),
        confidence=0.9,
    )
    assert person.track_id == 1
    assert person.class_name == "person"
def test_resolve_model_fallback():
    """A model path that does not exist falls back to the default YOLO weights."""
    resolved_path, resolved_type = resolve_model("/not/exist/model.rknn", "auto")
    assert (resolved_path, resolved_type) == ("yolo11n.pt", "yolo")
def test_resolve_model_auto_by_extension(tmp_path):
    """With model_type="auto", the backend is inferred from the file extension."""
    expected_by_suffix = {".pt": "yolo", ".rknn": "rknn", ".onnx": "onnx"}

    # Empty placeholder files: only the path and its suffix matter here.
    candidates = {}
    for suffix, backend in expected_by_suffix.items():
        model_file = tmp_path / f"model{suffix}"
        model_file.write_text("")
        candidates[str(model_file)] = backend

    for path_str, backend in candidates.items():
        assert resolve_model(path_str, "auto") == (path_str, backend)
def test_resolve_model_respects_explicit_type(tmp_path):
    """An explicit model_type wins over anything the extension might suggest."""
    odd_file = tmp_path / "weird.bin"
    odd_file.write_text("")
    odd_path = str(odd_file)
    for forced_type in ("rknn", "onnx", "yolo"):
        assert resolve_model(odd_path, forced_type) == (odd_path, forced_type)
def test_update_filters_non_person(fake_yolo):
    """Detections whose class name is not "person" are dropped by update()."""
    car_only = FakeResult(
        names={0: "car"},
        boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]),
    )
    fake_yolo([car_only])
    # Named so it does not shadow the imported ``tracker`` module.
    person_tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
    blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    assert person_tracker.update(blank_frame) == []
def test_update_invalid_frame_returns_empty(fake_yolo):
    """update(None) yields no detections even though the fake model reports a person."""
    fake_yolo([
        FakeResult(
            names={0: "person"},
            boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]),
        ),
    ])
    # Named so it does not shadow the imported ``tracker`` module.
    person_tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
    assert person_tracker.update(None) == []
def test_tracker_lifecycle(fake_yolo):
    """Exercise update -> reset -> release on a single tracker instance."""
    fake_yolo([
        FakeResult(
            names={0: "person"},
            boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]], ids=[42]),
        )
    ])
    # Named so it does not shadow the imported ``tracker`` module.
    trk = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo", use_gpu=False)
    frame = np.zeros((480, 640, 3), dtype=np.uint8)

    results = trk.update(frame)
    assert len(results) == 1
    assert results[0].track_id == 42
    assert results[0].bbox == (10, 20, 30, 40)

    # The fake predictor starts with ``trackers=[]``, so asserting ``== []``
    # right after reset() would pass even if reset() did nothing. Seed a stale
    # entry first so the assertion actually observes reset() clearing it.
    # The stand-in exposes a no-op ``reset`` in case reset() calls it per entry
    # before clearing. NOTE(review): assumes reset() empties the list — confirm
    # against UltralyticsTracker.reset.
    trk.model.predictor.trackers = [types.SimpleNamespace(reset=lambda: None)]
    trk.reset()
    assert trk.model.predictor.trackers == []

    trk.release()
    assert trk.model is None
    assert trk.byte_tracker is None
def _check_backend_fallback(
    fake_yolo,
    tmp_path,
    monkeypatch,
    caplog,
    *,
    suffix,
    detector_attr,
    model_type,
    import_message,
    log_marker,
):
    """Assert that a backend whose detector fails to import falls back to YOLO.

    Steps:
    - Create an empty model file ``model<suffix>``.
    - Replace ``inference_backend.<detector_attr>`` with a class whose
      constructor raises ``ImportError(import_message)``.
    - Install a fake YOLO that reports one person.

    Expected outcome for a tracker built with *model_type*:
    - it switches to "yolo";
    - it logs a WARNING-or-higher message containing *log_marker*;
    - it still detects the person on a blank frame.
    """
    model_file = tmp_path / f"model{suffix}"
    model_file.write_text("")

    class FailingDetector:
        def __init__(self, path):
            raise ImportError(import_message)

    monkeypatch.setattr(f"inference_backend.{detector_attr}", FailingDetector)
    fake_yolo([
        FakeResult(
            names={0: "person"},
            boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]]),
        )
    ])
    caplog.set_level(logging.WARNING, logger="tracker")

    # Named so it does not shadow the imported ``tracker`` module.
    trk = UltralyticsTracker(model_path=str(model_file), model_type=model_type, use_gpu=False)
    assert trk.model_type == "yolo"
    assert log_marker in caplog.text

    frame = np.zeros((480, 640, 3), dtype=np.uint8)
    results = trk.update(frame)
    assert len(results) == 1


def test_rknn_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
    """An RKNN model whose runtime cannot be imported falls back to YOLO."""
    _check_backend_fallback(
        fake_yolo,
        tmp_path,
        monkeypatch,
        caplog,
        suffix=".rknn",
        detector_attr="RKNNDetector",
        model_type="rknn",
        import_message="rknnlite not installed",
        # Log text emitted by the tracker module ("RKNN load failed").
        log_marker="RKNN 加载失败",
    )


def test_onnx_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
    """An ONNX model whose runtime cannot be imported falls back to YOLO."""
    _check_backend_fallback(
        fake_yolo,
        tmp_path,
        monkeypatch,
        caplog,
        suffix=".onnx",
        detector_attr="ONNXDetector",
        model_type="onnx",
        import_message="onnxruntime not installed",
        # Log text emitted by the tracker module ("ONNX load failed").
        log_marker="ONNX 加载失败",
    )
|