# test_onnx.py
  1. #!/usr/bin/env python3
  2. import cv2
  3. import numpy as np
  4. import onnxruntime as ort
# Class-id -> display-name map for the classes this script reports.
# NOTE(review): the decoder argmaxes over only 4 class rows (ids 0-3),
# so key 4 ('安全衣') can never be produced and ids 1-2 are silently
# dropped — confirm the id set against the model's training config.
LABEL_MAP = {0: '安全帽', 4: '安全衣', 3: '人'}
  6. def sigmoid(x):
  7. return 1 / (1 + np.exp(-x))
  8. def test_onnx():
  9. model_path = "yolo11m_safety.onnx"
  10. input_size = (640, 640)
  11. image = cv2.imread("b.jpg")
  12. if image is None:
  13. print("无法读取测试图片")
  14. return
  15. h0, w0 = image.shape[:2]
  16. print(f"=== 预处理 ===")
  17. img = cv2.resize(image, input_size)
  18. img = img[..., ::-1].astype(np.float32) / 255.0
  19. img = img.transpose(2, 0, 1)
  20. blob = img[None, ...]
  21. print(f"Input blob shape: {blob.shape}")
  22. print(f"\n=== 加载 ONNX 模型 ===")
  23. session = ort.InferenceSession(model_path)
  24. input_name = session.get_inputs()[0].name
  25. output_name = session.get_outputs()[0].name
  26. print(f"Input name: {input_name}")
  27. print(f"Output name: {output_name}")
  28. print(f"\n=== 推理 ===")
  29. outputs = session.run([output_name], {input_name: blob})
  30. output = outputs[0]
  31. print(f"Output shape: {output.shape}")
  32. print(f"\n=== 原始输出分析 ===")
  33. output_a = output[0]
  34. print(f"After squeeze shape: {output_a.shape}")
  35. print(f"\n=== 查找高置信度框 (obj_conf > 0.1) ===")
  36. high_obj_indices = []
  37. for i in range(output_a.shape[1]):
  38. obj_conf = output_a[4, i]
  39. if obj_conf > 0.1:
  40. high_obj_indices.append((i, obj_conf))
  41. high_obj_indices.sort(key=lambda x: x[1], reverse=True)
  42. print(f"找到 {len(high_obj_indices)} 个高置信度框")
  43. print(f"\n前 20 个高置信度框:")
  44. for idx, obj_conf in high_obj_indices[:20]:
  45. x_center = float(output_a[0, idx])
  46. y_center = float(output_a[1, idx])
  47. width = float(output_a[2, idx])
  48. height = float(output_a[3, idx])
  49. class_probs_raw = output_a[5:9, idx]
  50. class_probs = sigmoid(class_probs_raw)
  51. class_id = int(np.argmax(class_probs))
  52. class_conf = float(class_probs[class_id])
  53. confidence = obj_conf * class_conf
  54. print(f"\nBox {idx}:")
  55. print(f" 坐标: x={x_center:.3f}, y={y_center:.3f}, w={width:.3f}, h={height:.3f}")
  56. print(f" 置信度: obj_conf={obj_conf:.3f}, class_conf={class_conf:.6f}, total={confidence:.3f}")
  57. print(f" 类别: class_id={class_id}, name={LABEL_MAP.get(class_id, 'unknown')}")
  58. print(f"\n=== 检测结果 ===")
  59. dets = []
  60. h0, w0 = image.shape[:2]
  61. for idx, obj_conf in high_obj_indices:
  62. x_center = float(output_a[0, idx])
  63. y_center = float(output_a[1, idx])
  64. width = float(output_a[2, idx])
  65. height = float(output_a[3, idx])
  66. class_probs_raw = output_a[5:9, idx]
  67. class_probs = sigmoid(class_probs_raw)
  68. class_id = int(np.argmax(class_probs))
  69. class_conf = float(class_probs[class_id])
  70. confidence = obj_conf * class_conf
  71. if class_id not in LABEL_MAP:
  72. continue
  73. conf_threshold = 0.01
  74. if confidence < conf_threshold:
  75. continue
  76. x1 = int(x_center - width/2)
  77. y1 = int(y_center - height/2)
  78. x2 = int(x_center + width/2)
  79. y2 = int(y_center + height/2)
  80. x1 = int(x1 * (w0/640))
  81. y1 = int(y1 * (h0/640))
  82. x2 = int(x2 * (w0/640))
  83. y2 = int(y2 * (h0/640))
  84. x1 = max(0, x1)
  85. y1 = max(0, y1)
  86. x2 = min(w0, x2)
  87. y2 = min(h0, y2)
  88. dets.append({
  89. 'class_id': class_id,
  90. 'class_name': LABEL_MAP[class_id],
  91. 'confidence': confidence,
  92. 'bbox': (x1, y1, x2, y2)
  93. })
  94. print(f"\n检测到 {len(dets)} 个框 (阈值 0.01)")
  95. for d in dets[:20]:
  96. print(f" {d['class_name']}: conf={d['confidence']:.3f}, box={d['bbox']}")
# Script entry point: run the smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_onnx()