test_model.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #!/usr/bin/env python3
  2. import cv2
  3. import numpy as np
  4. from rknnlite.api import RKNNLite
  5. from dataclasses import dataclass
  6. from typing import List, Tuple, Optional
  7. @dataclass
  8. class Detection:
  9. class_id: int
  10. class_name: str
  11. confidence: float
  12. bbox: Tuple[int, int, int, int]
  13. LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}
  14. INPUT_SIZE = (640, 640)
  15. def nms(dets, iou_threshold=0.45):
  16. if len(dets) == 0:
  17. return []
  18. boxes = np.array([[d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3], d.confidence] for d in dets])
  19. x1 = boxes[:, 0]
  20. y1 = boxes[:, 1]
  21. x2 = boxes[:, 2]
  22. y2 = boxes[:, 3]
  23. scores = boxes[:, 4]
  24. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  25. order = scores.argsort()[::-1]
  26. keep = []
  27. while order.size > 0:
  28. i = order[0]
  29. keep.append(i)
  30. xx1 = np.maximum(x1[i], x1[order[1:]])
  31. yy1 = np.maximum(y1[i], y1[order[1:]])
  32. xx2 = np.minimum(x2[i], x2[order[1:]])
  33. yy2 = np.minimum(y2[i], y2[order[1:]])
  34. w = np.maximum(0.0, xx2 - xx1 + 1)
  35. h = np.maximum(0.0, yy2 - yy1 + 1)
  36. inter = w * h
  37. ovr = inter / (areas[i] + areas[order[1:]] - inter)
  38. inds = np.where(ovr <= iou_threshold)[0]
  39. order = order[inds + 1]
  40. return [dets[i] for i in keep]
  41. def letterbox(image, input_size=(640, 640)):
  42. h0, w0 = image.shape[:2]
  43. ih, iw = input_size
  44. scale = min(iw / w0, ih / h0)
  45. new_w, new_h = int(w0 * scale), int(h0 * scale)
  46. pad_w = (iw - new_w) // 2
  47. pad_h = (ih - new_h) // 2
  48. resized = cv2.resize(image, (new_w, new_h))
  49. canvas = np.full((ih, iw, 3), 114, dtype=np.uint8)
  50. canvas[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = resized
  51. return canvas, scale, pad_w, pad_h, h0, w0
  52. def test_model():
  53. model_path = "yolo11m_safety.rknn"
  54. conf_threshold_map = {3: 0.8, 0: 0.5, 4: 0.5}
  55. rknn = RKNNLite()
  56. ret = rknn.load_rknn(model_path)
  57. if ret != 0:
  58. print("[ERROR] load_rknn failed")
  59. return
  60. ret = rknn.init_runtime(core_mask=RKNNLite.NPU_CORE_0_1_2)
  61. if ret != 0:
  62. print("[ERROR] init_runtime failed")
  63. return
  64. image = cv2.imread("b.jpg")
  65. if image is None:
  66. print("无法读取测试图片")
  67. return
  68. canvas, scale, pad_w, pad_h, h0, w0 = letterbox(image)
  69. # RKNN expects NHWC input: (1, H, W, C), RGB, float32 normalized 0-1
  70. img = canvas[..., ::-1].astype(np.float32) / 255.0
  71. blob = img[None, ...] # (1, 640, 640, 3)
  72. outputs = rknn.inference(inputs=[blob])
  73. if outputs:
  74. output = outputs[0]
  75. if len(output.shape) == 3:
  76. output = output[0]
  77. num_classes = 5
  78. dets = []
  79. for i in range(output.shape[1]):
  80. x_center = float(output[0, i])
  81. y_center = float(output[1, i])
  82. width = float(output[2, i])
  83. height = float(output[3, i])
  84. class_probs = output[4:4+num_classes, i]
  85. best_class = int(np.argmax(class_probs))
  86. confidence = float(class_probs[best_class])
  87. if best_class not in LABEL_MAP:
  88. continue
  89. conf_threshold = conf_threshold_map.get(best_class, 0.5)
  90. if confidence < conf_threshold:
  91. continue
  92. # Remove padding and scale to original image
  93. x1 = int(((x_center - width / 2) - pad_w) / scale)
  94. y1 = int(((y_center - height / 2) - pad_h) / scale)
  95. x2 = int(((x_center + width / 2) - pad_w) / scale)
  96. y2 = int(((y_center + height / 2) - pad_h) / scale)
  97. x1 = max(0, min(w0, x1))
  98. y1 = max(0, min(h0, y1))
  99. x2 = max(0, min(w0, x2))
  100. y2 = max(0, min(h0, y2))
  101. det = Detection(
  102. class_id=best_class,
  103. class_name=LABEL_MAP[best_class],
  104. confidence=confidence,
  105. bbox=(x1, y1, x2, y2)
  106. )
  107. dets.append(det)
  108. dets = nms(dets, iou_threshold=0.45)
  109. print(f"检测结果: {len(dets)} 个目标")
  110. for d in dets:
  111. print(f" {d.class_name}: conf={d.confidence:.3f}, box={d.bbox}")
  112. rknn.release()
  113. if __name__ == "__main__":
  114. test_model()