tracker.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """
  2. Ultralytics Tracker 封装
  3. 支持 YOLO (.pt) 端到端跟踪 和 RKNN/ONNX 检测 + BYTETracker 关联
  4. """
  5. import logging
  6. import os
  7. import types
  8. from typing import Any, List, Tuple, Optional
  9. from dataclasses import dataclass
  10. import numpy as np
  11. from config import TRACKING_CONFIG
  12. from inference_backend import Detection
  logger = logging.getLogger(__name__)

  # Backend identifiers accepted by resolve_model / UltralyticsTracker.
  # "auto" asks resolve_model to infer the backend from the file extension.
  MODEL_TYPE_AUTO, MODEL_TYPE_RKNN, MODEL_TYPE_ONNX, MODEL_TYPE_YOLO = (
      "auto",
      "rknn",
      "onnx",
      "yolo",
  )

  # Weights used when neither the requested model nor the configured fallback
  # exists locally; Ultralytics fetches them automatically on first use.
  DEFAULT_YOLO_MODEL = "yolo11n.pt"
  @dataclass(init=True, repr=True, eq=True)
  class TrackedPerson:
      """One tracked person as reported for a single frame.

      ``bbox`` holds the box corners ``(x1, y1, x2, y2)`` in frame coordinates
      and ``center`` their integer midpoint. ``track_id`` is -1 when the
      producing path had no tracker identity for the box.
      """

      track_id: int
      bbox: Tuple[int, int, int, int]  # (x1, y1, x2, y2)
      center: Tuple[int, int]  # (cx, cy)
      confidence: float
      class_name: str = "person"
      # Defaults to False; none of the tracker paths in this module set it.
      lost: bool = False
  def resolve_model(model_path: Optional[str], model_type: str) -> Tuple[str, str]:
      """Decide which model file to load and which backend type to use.

      Candidates are tried in order:

      1. ``model_path`` if it exists on disk.
      2. Otherwise ``TRACKING_CONFIG['fallback_model_path']`` if it exists.
      3. Otherwise ``DEFAULT_YOLO_MODEL`` as a YOLO model (Ultralytics downloads
         it on first use).

      For candidates 1 and 2 an explicit ``model_type`` (anything other than
      ``MODEL_TYPE_AUTO``) is returned unchanged and takes precedence over the
      extension. With ``MODEL_TYPE_AUTO`` the type is inferred: ``.rknn`` -> RKNN,
      ``.onnx`` -> ONNX, anything else -> YOLO. Candidate 3 is always YOLO.

      Args:
          model_path: Preferred model file; may be ``None``.
          model_type: ``MODEL_TYPE_AUTO`` or an explicit ``MODEL_TYPE_*`` value.

      Returns:
          ``(resolved_path, resolved_type)``.
      """

      def _type_for(path: str) -> str:
          # An explicit type always wins over the file extension.
          if model_type != MODEL_TYPE_AUTO:
              return model_type
          ext = os.path.splitext(path)[1].lower()
          if ext == ".rknn":
              return MODEL_TYPE_RKNN
          if ext == ".onnx":
              return MODEL_TYPE_ONNX
          return MODEL_TYPE_YOLO

      if model_path and os.path.exists(model_path):
          return model_path, _type_for(model_path)
      if model_path:
          # A requested model that is missing used to be replaced silently,
          # which hides misconfigured paths; make the substitution visible.
          logger.warning("模型文件不存在: %s,尝试使用回退模型", model_path)

      fallback = TRACKING_CONFIG.get("fallback_model_path")
      if fallback and os.path.exists(fallback):
          return fallback, _type_for(fallback)

      logger.info("未找到本地模型,使用 Ultralytics 默认模型: %s", DEFAULT_YOLO_MODEL)
      return DEFAULT_YOLO_MODEL, MODEL_TYPE_YOLO
  class UltralyticsTracker:
      """Person tracker built on Ultralytics.

      Two execution paths are supported:

      * YOLO (Ultralytics-loadable weights such as ``.pt``): detection and
        tracking run end-to-end through ``YOLO.track`` using the tracker config
        ``"<tracker_type>.yaml"``.
      * RKNN / ONNX: detections come from ``inference_backend`` detectors and are
        associated across frames with Ultralytics' ``BYTETracker``. If that
        tracker is unavailable or fails, boxes are returned without identity
        (``track_id == -1``).

      Thresholds:

      - ``conf_threshold``: detection confidence handed to the model / tracker;
        controls how many candidate boxes enter the tracking pipeline.
      - ``person_threshold``: applied when parsing results; only "person"
        detections with confidence >= this value are kept.
      """

      # Class id treated as "person" on the RKNN/ONNX path. This differs from
      # COCO's person index (0); presumably the custom model's label order puts
      # person at 3 -- NOTE(review): confirm against the exported model labels.
      RKNN_PERSON_CLASS_ID = 3

      def __init__(
          self,
          model_path: Optional[str] = None,
          model_type: str = MODEL_TYPE_AUTO,
          use_gpu: bool = True,
          tracker_type: str = "bytetrack",
          conf_threshold: float = 0.5,
          person_threshold: float = 0.5,
          max_lost: int = 30,
      ):
          """Resolve and load the model.

          Args:
              model_path: Model file; ``None`` reads ``TRACKING_CONFIG["model_path"]``.
              model_type: A ``MODEL_TYPE_*`` value; ``"auto"`` lets
                  ``resolve_model`` infer it from the extension.
              use_gpu: Prefer ``cuda:0`` on the YOLO path (CPU is used when CUDA
                  is not available).
              tracker_type: Ultralytics tracker config name for the YOLO path.
              conf_threshold: Detection confidence used by the model / tracker.
              person_threshold: Minimum confidence of kept "person" boxes.
              max_lost: BYTETracker ``track_buffer`` on the RKNN/ONNX path (the
                  YOLO path takes its buffer from the tracker yaml instead).
          """
          if model_path is None:
              model_path = TRACKING_CONFIG["model_path"]
          self.use_gpu = use_gpu
          self.tracker_type = tracker_type
          self.conf_threshold = conf_threshold
          self.person_threshold = person_threshold
          self.max_lost = max_lost
          self.model = None
          self.rknn_detector = None
          self.byte_tracker = None
          self._yolo_device = "cpu"
          # Resolution may substitute a fallback file and/or infer the type.
          self.model_path, self.model_type = resolve_model(model_path, model_type)
          self._load_model()

      # ------------------------------------------------------------------ loading

      def _load_model(self) -> None:
          """Dispatch to the loader matching ``self.model_type``."""
          if self.model_type == MODEL_TYPE_RKNN:
              self._load_rknn_model()
          elif self.model_type == MODEL_TYPE_ONNX:
              self._load_onnx_model()
          else:
              self._load_yolo_model()

      def _select_device(self) -> str:
          """Return the device string for the YOLO path.

          ``use_gpu`` is honoured only when CUDA is actually available; asking
          Ultralytics for ``cuda:0`` on a CPU-only host makes it raise.
          """
          if not self.use_gpu:
              return "cpu"
          try:
              import torch
          except ImportError:
              return "cpu"
          return "cuda:0" if torch.cuda.is_available() else "cpu"

      def _load_yolo_model(self) -> None:
          """Load Ultralytics weights from ``self.model_path`` and warm up."""
          from ultralytics import YOLO

          self.model = YOLO(self.model_path)
          self._yolo_device = self._select_device()
          dummy = np.zeros((640, 640, 3), dtype=np.uint8)
          # Warm-up: one tracking pass on a blank image. ``YOLO.track`` is what
          # registers the tracker with the predictor (a plain ``model(...)`` call
          # runs prediction only), and it absorbs first-call initialisation so
          # the first real frame is not delayed.
          self.model.track(
              dummy,
              tracker=f"{self.tracker_type}.yaml",
              persist=True,
              verbose=False,
              device=self._yolo_device,
          )
          logger.info("YOLO 跟踪模型加载成功: %s", self.model_path)

      def _load_rknn_model(self) -> None:
          """Load the RKNN detector; fall back to YOLO if it cannot be imported."""
          try:
              from inference_backend import RKNNDetector
              self.rknn_detector = RKNNDetector(self.model_path)
          except ImportError as e:
              logger.warning("RKNN 加载失败 (%s),回退到 YOLO 模型", e)
              self._fall_back_to_yolo()
              return
          self._init_byte_tracker()
          logger.info("RKNN 跟踪模型加载成功: %s", self.model_path)

      def _load_onnx_model(self) -> None:
          """Load the ONNX detector; fall back to YOLO if it cannot be imported."""
          try:
              from inference_backend import ONNXDetector
              # Stored in ``rknn_detector`` as well: both share the detect() path.
              self.rknn_detector = ONNXDetector(self.model_path)
          except ImportError as e:
              logger.warning("ONNX 加载失败 (%s),回退到 YOLO 模型", e)
              self._fall_back_to_yolo()
              return
          self._init_byte_tracker()
          logger.info("ONNX 跟踪模型加载成功: %s", self.model_path)

      def _fall_back_to_yolo(self) -> None:
          """Switch to the YOLO path after the RKNN/ONNX backend failed to import.

          A bare ``.rknn`` file cannot be assumed loadable by ``YOLO(...)``
          (NOTE(review): confirm for the installed Ultralytics version), so in
          that case use ``TRACKING_CONFIG['fallback_model_path']`` when it is an
          existing ``.pt`` file, else ``DEFAULT_YOLO_MODEL``. Other paths (e.g.
          ``.onnx``) are handed to Ultralytics unchanged, as before.
          """
          if os.path.splitext(self.model_path)[1].lower() == ".rknn":
              fallback = TRACKING_CONFIG.get("fallback_model_path")
              usable = (
                  bool(fallback)
                  and os.path.exists(fallback)
                  and os.path.splitext(fallback)[1].lower() == ".pt"
              )
              self.model_path = fallback if usable else DEFAULT_YOLO_MODEL
          self.model_type = MODEL_TYPE_YOLO
          self._load_yolo_model()

      def _init_byte_tracker(self) -> None:
          """(Re)create the BYTETracker; leave it ``None`` if that fails."""
          try:
              from ultralytics.trackers.byte_tracker import BYTETracker
              self.byte_tracker = BYTETracker(args=self._tracker_args())
          except Exception as e:
              logger.warning("初始化 BYTETracker 失败: %s,将使用简化 IOU 关联", e)
              self.byte_tracker = None

      def _tracker_args(self) -> types.SimpleNamespace:
          """Build the argument namespace for Ultralytics' ``BYTETracker``.

          Key names follow Ultralytics' ``bytetrack.yaml``. The legacy ByteTrack
          names ``track_thresh`` / ``mot20`` are kept for compatibility.
          """
          return types.SimpleNamespace(
              track_high_thresh=self.conf_threshold,
              track_low_thresh=0.1,
              new_track_thresh=self.conf_threshold,
              track_buffer=self.max_lost,
              match_thresh=0.8,
              fuse_score=True,
              track_thresh=self.conf_threshold,
              mot20=False,
          )

      # ----------------------------------------------------------------- tracking

      def update(self, frame: Optional[np.ndarray]) -> List[TrackedPerson]:
          """Track persons in ``frame``; returns ``[]`` for a ``None`` frame."""
          if frame is None:
              return []
          if self.model_type == MODEL_TYPE_YOLO:
              return self._update_yolo(frame)
          return self._update_rknn_onnx(frame)

      def _update_yolo(self, frame: np.ndarray) -> List[TrackedPerson]:
          """Run Ultralytics detection + tracking on one frame."""
          results = self.model.track(
              frame,
              tracker=f"{self.tracker_type}.yaml",
              persist=True,
              conf=self.conf_threshold,
              verbose=False,
              device=self._yolo_device,
          )
          return self._parse_yolo_results(results, frame.shape)

      @staticmethod
      def _to_numpy(values: Any) -> np.ndarray:
          """Return ``values`` as a NumPy array, moving tensors to CPU first."""
          if hasattr(values, "cpu"):
              values = values.cpu().numpy()
          return np.asarray(values)

      @staticmethod
      def _make_person(bbox: Any, track_id: int, confidence: float) -> TrackedPerson:
          """Build a TrackedPerson from an ``(x1, y1, x2, y2, ...)`` box (coords truncated to int)."""
          x1, y1, x2, y2 = (int(v) for v in bbox[:4])
          return TrackedPerson(
              track_id=track_id,
              bbox=(x1, y1, x2, y2),
              center=((x1 + x2) // 2, (y1 + y2) // 2),
              confidence=confidence,
          )

      def _parse_yolo_results(self, results: List[Any], frame_shape: Tuple[int, ...]) -> List[TrackedPerson]:
          """Convert Ultralytics results into TrackedPerson objects.

          Keeps only boxes whose class name is "person" and whose confidence is
          >= ``person_threshold``. Boxes without a tracker id get ``-1``.
          ``frame_shape`` is accepted for interface compatibility and is unused.
          """
          persons: List[TrackedPerson] = []
          for res in results:
              boxes = res.boxes
              if boxes is None or len(boxes) == 0:
                  continue
              # Convert each tensor once per result instead of once per box.
              cls_ids = self._to_numpy(boxes.cls)
              confs = self._to_numpy(boxes.conf)
              coords = self._to_numpy(boxes.xyxy)
              ids = None if boxes.id is None else self._to_numpy(boxes.id)
              for i, (cls_id, conf, xyxy) in enumerate(zip(cls_ids, confs, coords)):
                  cls_id = int(cls_id)
                  if res.names.get(cls_id, str(cls_id)) != "person":
                      continue
                  conf = float(conf)
                  if conf < self.person_threshold:
                      continue
                  track_id = int(ids[i]) if ids is not None else -1
                  persons.append(self._make_person(xyxy, track_id, conf))
          return persons

      @staticmethod
      def _to_tracker_input(dets: List[Detection]) -> types.SimpleNamespace:
          """Pack detections into the Boxes-like object ``BYTETracker.update`` reads.

          Provides ``conf``, ``cls`` and boxes as both ``xyxy`` and ``xywh``
          (center x, center y, width, height) as NumPy arrays.
          """
          xyxy = np.asarray([d.bbox for d in dets], dtype=np.float32).reshape(-1, 4)
          wh = xyxy[:, 2:] - xyxy[:, :2]
          xywh = np.concatenate([xyxy[:, :2] + wh / 2, wh], axis=1)
          return types.SimpleNamespace(
              xyxy=xyxy,
              xywh=xywh,
              conf=np.asarray([d.confidence for d in dets], dtype=np.float32),
              cls=np.asarray([d.class_id for d in dets], dtype=np.float32),
          )

      def _update_rknn_onnx(self, frame: np.ndarray) -> List[TrackedPerson]:
          """Detect persons with the RKNN/ONNX detector and associate them."""
          person_id = self.RKNN_PERSON_CLASS_ID
          detections = self.rknn_detector.detect(frame, {person_id: self.person_threshold})
          person_dets = [d for d in detections if d.class_id == person_id]
          if not person_dets:
              # Nothing to associate; the tracker is not advanced on empty frames.
              return []
          if self.byte_tracker is None:
              return self._simple_association(person_dets)
          try:
              tracks = self.byte_tracker.update(self._to_tracker_input(person_dets), frame)
              # One row per active track; per Ultralytics' STrack.result the
              # layout is x1, y1, x2, y2, track_id, score, cls, det_index.
              return [self._make_person(row[:4], int(row[4]), float(row[5])) for row in tracks]
          except Exception as e:
              logger.warning("BYTETracker 更新失败: %s,使用简化关联", e)
              return self._simple_association(person_dets)

      def _simple_association(self, detections: List[Detection]) -> List[TrackedPerson]:
          """Fallback without tracking: every box is returned with ``track_id == -1``."""
          return [self._make_person(d.bbox, -1, d.confidence) for d in detections]

      # ---------------------------------------------------------------- lifecycle

      def reset(self) -> None:
          """Discard all tracking state.

          For YOLO, each tracker held by the predictor is reset in place.
          Replacing the list with ``[]`` (as before) leaves the predictor with
          no trackers, because ``persist=True`` stops Ultralytics from
          recreating them, and the next update then indexes into an empty list.
          """
          if self.model_type == MODEL_TYPE_YOLO and self.model is not None:
              predictor = getattr(self.model, "predictor", None)
              for tracker in getattr(predictor, "trackers", None) or []:
                  tracker.reset()
          if self.byte_tracker is not None:
              self._init_byte_tracker()

      def release(self) -> None:
          """Release the detector backend and drop model / tracker references."""
          if self.rknn_detector is not None:
              self.rknn_detector.release()
              self.rknn_detector = None
          self.model = None
          self.byte_tracker = None