# inference_backend.py
"""
Generic inference backend.

Provides UltralyticsTracker with a unified detection interface for
RKNN / ONNX models, decoupled from the safety detection logic
(hard hats / reflective vests).
"""
  6. import os
  7. import cv2
  8. import numpy as np
  9. from typing import List, Tuple, Dict, Any
  10. from dataclasses import dataclass
  11. @dataclass
  12. class Detection:
  13. """检测结果 (用于 RKNN/ONNX 模型)"""
  14. class_id: int
  15. class_name: str
  16. confidence: float
  17. bbox: Tuple[int, int, int, int]
  18. def nms(dets, iou_threshold=0.45):
  19. """非极大值抑制"""
  20. if len(dets) == 0:
  21. return []
  22. boxes = np.array([[d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3], d.confidence] for d in dets])
  23. x1 = boxes[:, 0]
  24. y1 = boxes[:, 1]
  25. x2 = boxes[:, 2]
  26. y2 = boxes[:, 3]
  27. scores = boxes[:, 4]
  28. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  29. order = scores.argsort()[::-1]
  30. keep = []
  31. while order.size > 0:
  32. i = order[0]
  33. keep.append(i)
  34. xx1 = np.maximum(x1[i], x1[order[1:]])
  35. yy1 = np.maximum(y1[i], y1[order[1:]])
  36. xx2 = np.minimum(x2[i], x2[order[1:]])
  37. yy2 = np.minimum(y2[i], y2[order[1:]])
  38. w = np.maximum(0.0, xx2 - xx1 + 1)
  39. h = np.maximum(0.0, yy2 - yy1 + 1)
  40. inter = w * h
  41. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  42. inds = np.where(ovr <= iou_threshold)[0]
  43. order = order[inds + 1]
  44. return [dets[i] for i in keep]
  45. class BaseDetector:
  46. """检测器基类 (用于 RKNN/ONNX 模型)"""
  47. # 默认 COCO 类别映射;子类可覆盖
  48. LABEL_MAP = {0: 'person'}
  49. def __init__(self, label_map: Dict[int, str] = None):
  50. self.input_size = (640, 640)
  51. self.num_classes = len(label_map) if label_map else max(self.LABEL_MAP.keys()) + 1
  52. if label_map:
  53. self.LABEL_MAP = label_map
  54. def letterbox(self, image):
  55. """Letterbox 预处理,保持宽高比"""
  56. h0, w0 = image.shape[:2]
  57. ih, iw = self.input_size
  58. scale = min(iw / w0, ih / h0)
  59. new_w, new_h = int(w0 * scale), int(h0 * scale)
  60. pad_w = (iw - new_w) // 2
  61. pad_h = (ih - new_h) // 2
  62. resized = cv2.resize(image, (new_w, new_h))
  63. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  64. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  65. return canvas, scale, pad_w, pad_h, h0, w0
  66. def postprocess(self, outputs, scale, pad_w, pad_h, h0, w0, conf_threshold_map):
  67. """后处理"""
  68. dets = []
  69. if not outputs:
  70. return dets
  71. output = outputs[0]
  72. if len(output.shape) == 3:
  73. output = output[0]
  74. num_boxes = output.shape[1]
  75. for i in range(num_boxes):
  76. x_center = float(output[0, i])
  77. y_center = float(output[1, i])
  78. width = float(output[2, i])
  79. height = float(output[3, i])
  80. class_probs = output[4:4+self.num_classes, i]
  81. best_class = int(np.argmax(class_probs))
  82. confidence = float(class_probs[best_class])
  83. if best_class not in self.LABEL_MAP:
  84. continue
  85. conf_threshold = conf_threshold_map.get(best_class, 0.5)
  86. if confidence < conf_threshold:
  87. continue
  88. # 移除 padding 并缩放到原始图像尺寸
  89. x1 = int(((x_center - width / 2) - pad_w) / scale)
  90. y1 = int(((y_center - height / 2) - pad_h) / scale)
  91. x2 = int(((x_center + width / 2) - pad_w) / scale)
  92. y2 = int(((y_center + height / 2) - pad_h) / scale)
  93. x1 = max(0, min(w0, x1))
  94. y1 = max(0, min(h0, y1))
  95. x2 = max(0, min(w0, x2))
  96. y2 = max(0, min(h0, y2))
  97. det = Detection(
  98. class_id=best_class,
  99. class_name=self.LABEL_MAP[best_class],
  100. confidence=confidence,
  101. bbox=(x1, y1, x2, y2)
  102. )
  103. dets.append(det)
  104. dets = nms(dets, iou_threshold=0.45)
  105. return dets
  106. def detect(self, image, conf_threshold_map):
  107. raise NotImplementedError
  108. def release(self):
  109. pass
  110. class RKNNDetector(BaseDetector):
  111. """RKNN 检测器 - 使用 NHWC 输入格式 (1, H, W, C)"""
  112. def __init__(self, model_path: str, label_map: Dict[int, str] = None):
  113. super().__init__(label_map=label_map)
  114. self.model_path = model_path
  115. self.rknn = None
  116. try:
  117. from rknnlite.api import RKNNLite
  118. self.rknn = RKNNLite()
  119. except ImportError:
  120. raise ImportError("未安装 rknnlite,请运行: pip install rknnlite2 或参考 testrk3588/setup_rknn.sh")
  121. ret = self.rknn.load_rknn(model_path)
  122. if ret != 0:
  123. raise RuntimeError(f"加载 RKNN 模型失败: {model_path}")
  124. ret = self.rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  125. if ret != 0:
  126. raise RuntimeError("初始化 RKNN 运行时失败")
  127. print(f"RKNN 模型加载成功: {model_path}")
  128. def detect(self, image, conf_threshold_map):
  129. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  130. # RKNN 期望 NHWC (1, H, W, C), RGB, 归一化 0-1
  131. img = canvas[..., ::-1].astype(np.float32) / 255.0
  132. blob = img[None, ...] # (1, 640, 640, 3)
  133. outs = self.rknn.inference(inputs=[blob])
  134. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  135. def release(self):
  136. if self.rknn:
  137. self.rknn.release()
  138. self.rknn = None
  139. class ONNXDetector(BaseDetector):
  140. """ONNX 检测器 - 使用 NCHW 输入格式 (1, C, H, W)"""
  141. def __init__(self, model_path: str, label_map: Dict[int, str] = None):
  142. super().__init__(label_map=label_map)
  143. self.model_path = model_path
  144. try:
  145. import onnxruntime as ort
  146. self.session = ort.InferenceSession(model_path)
  147. self.input_name = self.session.get_inputs()[0].name
  148. self.output_name = self.session.get_outputs()[0].name
  149. print(f"ONNX 模型加载成功: {model_path}")
  150. except ImportError:
  151. raise ImportError("未安装 onnxruntime,请运行: pip install onnxruntime")
  152. except Exception as e:
  153. raise RuntimeError(f"加载 ONNX 模型失败: {e}")
  154. def detect(self, image, conf_threshold_map):
  155. canvas, scale, pad_w, pad_h, h0, w0 = self.letterbox(image)
  156. # ONNX 期望 NCHW (1, C, H, W), RGB, 归一化 0-1
  157. img = canvas[..., ::-1].astype(np.float32) / 255.0
  158. img = img.transpose(2, 0, 1)
  159. blob = img[None, ...] # (1, 3, 640, 640)
  160. outs = self.session.run([self.output_name], {self.input_name: blob})
  161. return self.postprocess(outs, scale, pad_w, pad_h, h0, w0, conf_threshold_map)
  162. def release(self):
  163. self.session = None