re_match_orb.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
  2. """
  3. 用 ORB 重新匹配已扫描的球机图与全景图,生成更可靠的映射表。
  4. 用法:
  5. python scripts/re_match_orb.py --scan-dir /home/admin/dsh/calibration_scan
  6. """
import argparse
import csv
import json
import logging
import os
import re
import sys
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import numpy as np
  16. logging.basicConfig(
  17. level=logging.INFO,
  18. format='%(asctime)s - %(levelname)s - %(message)s'
  19. )
  20. logger = logging.getLogger(__name__)
  21. def match_orb(
  22. ptz_img: np.ndarray,
  23. panorama_img: np.ndarray,
  24. min_matches: int = 8,
  25. ratio_thresh: float = 0.8,
  26. ) -> Tuple[Optional[Tuple[float, float]], Optional[np.ndarray]]:
  27. """ORB 特征匹配"""
  28. if ptz_img is None or panorama_img is None:
  29. return None, None
  30. gray_p = cv2.cvtColor(ptz_img, cv2.COLOR_BGR2GRAY)
  31. gray_g = cv2.cvtColor(panorama_img, cv2.COLOR_BGR2GRAY)
  32. orb = cv2.ORB_create(nfeatures=1000)
  33. kp_p, des_p = orb.detectAndCompute(gray_p, None)
  34. kp_g, des_g = orb.detectAndCompute(gray_g, None)
  35. if des_p is None or des_g is None or len(kp_p) < 4 or len(kp_g) < 4:
  36. return None, None
  37. bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  38. matches = bf.knnMatch(des_p, des_g, k=2)
  39. good = []
  40. for m_n in matches:
  41. if len(m_n) == 2:
  42. m, n = m_n
  43. if m.distance < ratio_thresh * n.distance:
  44. good.append(m)
  45. if len(good) < min_matches:
  46. return None, None
  47. pts_p = np.float32([kp_p[m.queryIdx].pt for m in good])
  48. pts_g = np.float32([kp_g[m.trainIdx].pt for m in good])
  49. H, mask = cv2.findHomography(pts_p, pts_g, cv2.RANSAC, 5.0)
  50. if mask is None:
  51. return None, None
  52. inlier_g = pts_g[mask.ravel() == 1]
  53. if len(inlier_g) < min_matches:
  54. return None, None
  55. center_x = float(np.median(inlier_g[:, 0]))
  56. center_y = float(np.median(inlier_g[:, 1]))
  57. h, w = panorama_img.shape[:2]
  58. x_ratio = np.clip(center_x / w, 0.0, 1.0)
  59. y_ratio = np.clip(center_y / h, 0.0, 1.0)
  60. vis = cv2.drawMatches(
  61. ptz_img, kp_p, panorama_img, kp_g,
  62. [good[i] for i in range(len(good)) if mask[i]],
  63. None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS
  64. )
  65. return (x_ratio, y_ratio), vis
  66. def run_rematch(scan_dir: Path):
  67. ptz_dir = scan_dir / 'ptz_images'
  68. pano_path = scan_dir / 'panorama.jpg'
  69. out_dir = scan_dir / 'rematch_orb'
  70. out_dir.mkdir(exist_ok=True)
  71. panorama = cv2.imread(str(pano_path))
  72. if panorama is None:
  73. logger.error(f'无法读取全景图: {pano_path}')
  74. return
  75. h, w = panorama.shape[:2]
  76. records = []
  77. for img_path in sorted(ptz_dir.glob('ptz_p*_t*.jpg')):
  78. # 解析文件名中的 pan/tilt
  79. stem = img_path.stem # e.g. ptz_p080_t-05
  80. parts = stem.replace('ptz_p', '').split('_t')
  81. pan = int(parts[0])
  82. tilt = int(parts[1])
  83. ptz_img = cv2.imread(str(img_path))
  84. if ptz_img is None:
  85. continue
  86. pos, vis = match_orb(ptz_img, panorama)
  87. record = {
  88. 'filename': img_path.name,
  89. 'pan': pan,
  90. 'tilt': tilt,
  91. 'matched': pos is not None,
  92. }
  93. if pos:
  94. record['x_ratio'] = round(pos[0], 4)
  95. record['y_ratio'] = round(pos[1], 4)
  96. record['panorama_x'] = int(pos[0] * w)
  97. record['panorama_y'] = int(pos[1] * h)
  98. vis_path = out_dir / f'match_{img_path.name}'
  99. cv2.imwrite(str(vis_path), vis, [int(cv2.IMWRITE_JPEG_QUALITY), 85])
  100. logger.info(f'{img_path.name} -> ({pos[0]:.3f}, {pos[1]:.3f})')
  101. else:
  102. logger.info(f'{img_path.name} -> 未匹配')
  103. records.append(record)
  104. # 保存映射表
  105. mapping_path = scan_dir / 'mapping_raw_orb.json'
  106. with open(mapping_path, 'w', encoding='utf-8') as f:
  107. json.dump({
  108. 'method': 'ORB',
  109. 'panorama_size': {'width': w, 'height': h},
  110. 'records': records,
  111. }, f, indent=2, ensure_ascii=False)
  112. logger.info(f'原始映射已保存: {mapping_path}')
  113. # 生成查找表
  114. valid = [r for r in records if r.get('matched')]
  115. pan_lookup = sorted([[r['x_ratio'], float(r['pan'])] for r in valid], key=lambda x: x[0])
  116. tilt_lookup = sorted([[r['y_ratio'], float(r['tilt'])] for r in valid], key=lambda x: x[0])
  117. lookup_path = scan_dir / 'lookup_table_orb.json'
  118. with open(lookup_path, 'w', encoding='utf-8') as f:
  119. json.dump({
  120. 'method': 'ORB',
  121. 'pan_lookup': pan_lookup,
  122. 'tilt_lookup': tilt_lookup,
  123. 'valid_count': len(valid),
  124. }, f, indent=2, ensure_ascii=False)
  125. logger.info(f'ORB 查找表已保存: {lookup_path} (有效 {len(valid)}/{len(records)})')
  126. # 更新 CSV
  127. csv_path = scan_dir / 'mapping_for_review_orb.csv'
  128. with open(csv_path, 'w', encoding='utf-8') as f:
  129. f.write('filename,pan,tilt,x_ratio,y_ratio,panorama_x,panorama_y,matched,review_x,review_y\n')
  130. for r in records:
  131. f.write(
  132. f"{r['filename']},{r['pan']},{r['tilt']},"
  133. f"{r.get('x_ratio', '')},{r.get('y_ratio', '')},"
  134. f"{r.get('panorama_x', '')},{r.get('panorama_y', '')},"
  135. f"{r['matched']},,\n"
  136. )
  137. logger.info(f'复核 CSV 已保存: {csv_path}')
  138. def main():
  139. parser = argparse.ArgumentParser()
  140. parser.add_argument('--scan-dir', type=str, default='/home/admin/dsh/calibration_scan',
  141. help='扫描结果目录')
  142. args = parser.parse_args()
  143. run_rematch(Path(args.scan_dir))
  144. if __name__ == '__main__':
  145. main()