test_tracker.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import sys
  2. import os
  3. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  4. import logging
  5. import types
  6. import numpy as np
  7. import pytest
  8. import ultralytics
  9. import tracker
  10. from tracker import UltralyticsTracker, TrackedPerson, resolve_model
  11. class FakeBox:
  12. def __init__(self, cls, conf, xyxy, ids=None):
  13. self.cls = np.array(cls)
  14. self.conf = np.array(conf)
  15. self.xyxy = np.array(xyxy)
  16. self.id = np.array(ids) if ids is not None else None
  17. def __len__(self):
  18. return len(self.cls)
  19. class FakeResult:
  20. def __init__(self, names, boxes):
  21. self.names = names
  22. self.boxes = boxes
  23. @pytest.fixture
  24. def fake_yolo(monkeypatch):
  25. """返回一个函数,用于将 ultralytics.YOLO 替换为返回指定结果的假实现。"""
  26. def _make(results):
  27. class FakeYOLO:
  28. def __init__(self, path):
  29. self.path = path
  30. self.predictor = types.SimpleNamespace(trackers=[])
  31. self._results = results
  32. def __call__(self, frame, **kwargs):
  33. return self._results
  34. monkeypatch.setattr(ultralytics, 'YOLO', FakeYOLO)
  35. return _make
  36. def test_tracked_person_dataclass():
  37. p = TrackedPerson(track_id=1, bbox=(10, 20, 30, 40), center=(20, 30), confidence=0.9)
  38. assert p.track_id == 1
  39. assert p.class_name == "person"
  40. def test_resolve_model_fallback():
  41. # 不存在的路径应回退到默认 YOLO 模型
  42. path, mtype = resolve_model("/not/exist/model.rknn", "auto")
  43. assert path == "yolo11n.pt"
  44. assert mtype == "yolo"
  45. def test_resolve_model_auto_by_extension(tmp_path):
  46. # 创建空文件,仅用于路径/扩展名推断
  47. pt_path = tmp_path / "model.pt"
  48. rknn_path = tmp_path / "model.rknn"
  49. onnx_path = tmp_path / "model.onnx"
  50. pt_path.write_text("")
  51. rknn_path.write_text("")
  52. onnx_path.write_text("")
  53. assert resolve_model(str(pt_path), "auto") == (str(pt_path), "yolo")
  54. assert resolve_model(str(rknn_path), "auto") == (str(rknn_path), "rknn")
  55. assert resolve_model(str(onnx_path), "auto") == (str(onnx_path), "onnx")
  56. def test_resolve_model_respects_explicit_type(tmp_path):
  57. # 显式 model_type 优先于扩展名推断
  58. path = tmp_path / "weird.bin"
  59. path.write_text("")
  60. assert resolve_model(str(path), "rknn") == (str(path), "rknn")
  61. assert resolve_model(str(path), "onnx") == (str(path), "onnx")
  62. assert resolve_model(str(path), "yolo") == (str(path), "yolo")
  63. def test_update_filters_non_person(fake_yolo):
  64. fake_yolo([
  65. FakeResult(names={0: "car"}, boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]))
  66. ])
  67. tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
  68. frame = np.zeros((480, 640, 3), dtype=np.uint8)
  69. results = tracker.update(frame)
  70. assert results == []
  71. def test_update_invalid_frame_returns_empty(fake_yolo):
  72. fake_yolo([
  73. FakeResult(names={0: "person"}, boxes=FakeBox(cls=[0], conf=[0.9], xyxy=[[10, 20, 30, 40]]))
  74. ])
  75. tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo")
  76. assert tracker.update(None) == []
  77. def test_tracker_lifecycle(fake_yolo):
  78. fake_yolo([
  79. FakeResult(
  80. names={0: "person"},
  81. boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]], ids=[42]),
  82. )
  83. ])
  84. tracker = UltralyticsTracker(model_path="/fake/yolo11n.pt", model_type="yolo", use_gpu=False)
  85. frame = np.zeros((480, 640, 3), dtype=np.uint8)
  86. results = tracker.update(frame)
  87. assert len(results) == 1
  88. assert results[0].track_id == 42
  89. assert results[0].bbox == (10, 20, 30, 40)
  90. tracker.reset()
  91. assert tracker.model.predictor.trackers == []
  92. tracker.release()
  93. assert tracker.model is None
  94. assert tracker.byte_tracker is None
  95. def test_rknn_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
  96. rknn_path = tmp_path / "model.rknn"
  97. rknn_path.write_text("")
  98. class FailingRKNN:
  99. def __init__(self, path):
  100. raise ImportError("rknnlite not installed")
  101. monkeypatch.setattr("inference_backend.RKNNDetector", FailingRKNN)
  102. fake_yolo([
  103. FakeResult(
  104. names={0: "person"},
  105. boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]]),
  106. )
  107. ])
  108. caplog.set_level(logging.WARNING, logger="tracker")
  109. tracker = UltralyticsTracker(model_path=str(rknn_path), model_type="rknn", use_gpu=False)
  110. assert tracker.model_type == "yolo"
  111. assert "RKNN 加载失败" in caplog.text
  112. frame = np.zeros((480, 640, 3), dtype=np.uint8)
  113. results = tracker.update(frame)
  114. assert len(results) == 1
  115. def test_onnx_import_fallback(fake_yolo, tmp_path, monkeypatch, caplog):
  116. onnx_path = tmp_path / "model.onnx"
  117. onnx_path.write_text("")
  118. class FailingONNX:
  119. def __init__(self, path):
  120. raise ImportError("onnxruntime not installed")
  121. monkeypatch.setattr("inference_backend.ONNXDetector", FailingONNX)
  122. fake_yolo([
  123. FakeResult(
  124. names={0: "person"},
  125. boxes=FakeBox(cls=[0], conf=[0.8], xyxy=[[10, 20, 30, 40]]),
  126. )
  127. ])
  128. caplog.set_level(logging.WARNING, logger="tracker")
  129. tracker = UltralyticsTracker(model_path=str(onnx_path), model_type="onnx", use_gpu=False)
  130. assert tracker.model_type == "yolo"
  131. assert "ONNX 加载失败" in caplog.text
  132. frame = np.zeros((480, 640, 3), dtype=np.uint8)
  133. results = tracker.update(frame)
  134. assert len(results) == 1