reanalyze_calibration.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #!/usr/bin/env python3
  2. """
  3. 用多尺度模板匹配重新分析已拍摄的 PTZ 校准图像,生成更可靠的映射表。
  4. """
  5. import os
  6. import sys
  7. import json
  8. import argparse
  9. from pathlib import Path
  10. from datetime import datetime
  11. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  12. import cv2
  13. import numpy as np
  14. def multi_scale_template_match(ptz_img: np.ndarray, panorama_img: np.ndarray,
  15. scales=(0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5)):
  16. """
  17. 在全景图中搜索 PTZ 图像的最佳匹配位置。
  18. 返回 (x_ratio, y_ratio, best_scale, score)
  19. """
  20. if ptz_img is None or panorama_img is None:
  21. return None
  22. ptz_gray = cv2.cvtColor(ptz_img, cv2.COLOR_BGR2GRAY)
  23. pano_gray = cv2.cvtColor(panorama_img, cv2.COLOR_BGR2GRAY)
  24. ph, pw = ptz_gray.shape
  25. pano_h, pano_w = pano_gray.shape
  26. best = None
  27. best_score = -1
  28. for scale in scales:
  29. resized_w = int(pw * scale)
  30. resized_h = int(ph * scale)
  31. if resized_w > pano_w or resized_h > pano_h:
  32. continue
  33. resized = cv2.resize(ptz_gray, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
  34. result = cv2.matchTemplate(pano_gray, resized, cv2.TM_CCOEFF_NORMED)
  35. _, max_val, _, max_loc = cv2.minMaxLoc(result)
  36. if max_val > best_score:
  37. best_score = max_val
  38. center_x = max_loc[0] + resized_w / 2
  39. center_y = max_loc[1] + resized_h / 2
  40. best = (center_x / pano_w, center_y / pano_h, scale, float(max_val))
  41. return best
  42. def main():
  43. parser = argparse.ArgumentParser()
  44. parser.add_argument('--input-dir', type=str, required=True, help='扫描结果目录')
  45. parser.add_argument('--output-dir', type=str, default=None, help='输出目录')
  46. parser.add_argument('--score-threshold', type=float, default=0.45, help='匹配分数阈值')
  47. args = parser.parse_args()
  48. input_dir = Path(args.input_dir)
  49. output_dir = Path(args.output_dir) if args.output_dir else input_dir / 'reanalysis'
  50. output_dir.mkdir(parents=True, exist_ok=True)
  51. ptz_dir = input_dir / 'ptz_images'
  52. panorama_path = input_dir / 'panorama.jpg'
  53. panorama = cv2.imread(str(panorama_path))
  54. if panorama is None:
  55. print(f'无法读取全景图: {panorama_path}')
  56. return
  57. pano_h, pano_w = panorama.shape[:2]
  58. print(f'全景图尺寸: {pano_w}x{pano_h}')
  59. raw_path = input_dir / 'mapping_raw.json'
  60. with open(raw_path, 'r', encoding='utf-8') as f:
  61. raw_data = json.load(f)
  62. records = []
  63. for r in raw_data['records']:
  64. filename = r['filename']
  65. ptz_path = ptz_dir / filename
  66. ptz_img = cv2.imread(str(ptz_path))
  67. if ptz_img is None:
  68. print(f' 无法读取: {ptz_path}')
  69. continue
  70. result = multi_scale_template_match(ptz_img, panorama)
  71. record = dict(r)
  72. if result and result[3] >= args.score_threshold:
  73. x_ratio, y_ratio, scale, score = result
  74. record['tm_x_ratio'] = round(x_ratio, 4)
  75. record['tm_y_ratio'] = round(y_ratio, 4)
  76. record['tm_panorama_x'] = int(x_ratio * pano_w)
  77. record['tm_panorama_y'] = int(y_ratio * pano_h)
  78. record['tm_scale'] = round(scale, 3)
  79. record['tm_score'] = round(score, 3)
  80. print(f"{filename}: pan={r['pan']:3d} tilt={r['tilt']:+3d} -> x={x_ratio:.3f} y={y_ratio:.3f} scale={scale:.2f} score={score:.3f}")
  81. else:
  82. if result:
  83. print(f"{filename}: pan={r['pan']:3d} tilt={r['tilt']:+3d} -> score too low: {result[3]:.3f}")
  84. else:
  85. print(f"{filename}: pan={r['pan']:3d} tilt={r['tilt']:+3d} -> no match")
  86. records.append(record)
  87. # 保存结果
  88. output_path = output_dir / 'mapping_tm.json'
  89. with open(output_path, 'w', encoding='utf-8') as f:
  90. json.dump({
  91. 'created_at': datetime.now().isoformat(),
  92. 'panorama_size': {'width': pano_w, 'height': pano_h},
  93. 'score_threshold': args.score_threshold,
  94. 'records': records,
  95. }, f, indent=2, ensure_ascii=False)
  96. print(f'\n结果已保存: {output_path}')
  97. # 生成查找表
  98. valid = [r for r in records if 'tm_x_ratio' in r]
  99. pan_lookup = sorted([[r['tm_x_ratio'], float(r['pan'])] for r in valid], key=lambda x: x[0])
  100. tilt_lookup = sorted([[r['tm_y_ratio'], float(r['tilt'])] for r in valid], key=lambda x: x[0])
  101. lookup = {
  102. 'created_at': datetime.now().isoformat(),
  103. 'pan_lookup': pan_lookup,
  104. 'tilt_lookup': tilt_lookup,
  105. 'valid_count': len(valid),
  106. }
  107. lookup_path = output_dir / 'lookup_table_tm.json'
  108. with open(lookup_path, 'w', encoding='utf-8') as f:
  109. json.dump(lookup, f, indent=2, ensure_ascii=False)
  110. print(f'查找表已保存: {lookup_path} (有效点 {len(valid)}/{len(records)})')
  111. if __name__ == '__main__':
  112. main()