match_lightglue.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. #!/usr/bin/env python3
  2. """
  3. 使用 LightGlue (SuperPoint) 对 PTZ 校准扫描图片与全景图做深度学习特征匹配,
  4. 在本地生成全景->PTZ 映射表。
  5. """
  6. import os
  7. import json
  8. import argparse
  9. import logging
  10. from pathlib import Path
  11. from typing import Tuple, Optional, Dict, Any, List
  12. import torch
  13. import cv2
  14. import numpy as np
  15. from lightglue import LightGlue, SuperPoint, match_pair
  16. from lightglue.utils import rbd
  17. logging.basicConfig(
  18. level=logging.INFO,
  19. format='%(asctime)s - %(levelname)s - %(message)s'
  20. )
  21. logger = logging.getLogger(__name__)
  22. def load_tensor(bgr_img: np.ndarray, device: torch.device) -> torch.Tensor:
  23. """BGR numpy -> RGB float tensor (C,H,W)"""
  24. rgb = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
  25. t = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
  26. return t.to(device)
  27. def resize_max_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
  28. """保持宽高比缩放,使长边不超过 max_side"""
  29. h, w = img.shape[:2]
  30. scale = min(max_side / max(h, w), 1.0)
  31. if scale == 1.0:
  32. return img.copy(), 1.0
  33. new_w, new_h = int(round(w * scale)), int(round(h * scale))
  34. return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA), scale
  35. def center_crop(img: np.ndarray, ratio: float) -> np.ndarray:
  36. """按中心裁剪 ratio 比例的区域"""
  37. h, w = img.shape[:2]
  38. new_w, new_h = int(w * ratio), int(h * ratio)
  39. x0, y0 = (w - new_w) // 2, (h - new_h) // 2
  40. return img[y0:y0 + new_h, x0:x0 + new_w]
  41. def median_point_from_matches(
  42. mkpts_panorama: np.ndarray,
  43. mkpts_ptz: np.ndarray,
  44. pano_scale: float,
  45. ptz_scale: float,
  46. ptz_offset: Tuple[int, int],
  47. pano_h: int,
  48. pano_w: int,
  49. ) -> Optional[Tuple[float, float, int, float, float]]:
  50. """
  51. 根据匹配点计算 PTZ 中心在全景图中的对应位置。
  52. 返回 (x_ratio, y_ratio, num_matches, median_confidence, spread)
  53. """
  54. if len(mkpts_panorama) == 0:
  55. return None
  56. # 把缩放后的坐标还原到原图
  57. mkpts_panorama_orig = mkpts_panorama / pano_scale
  58. mkpts_ptz_orig = (mkpts_ptz / ptz_scale) + np.array(ptz_offset)
  59. # 取全景坐标中位数作为 PTZ 视场中心投影
  60. med_x = float(np.median(mkpts_panorama_orig[:, 0]))
  61. med_y = float(np.median(mkpts_panorama_orig[:, 1]))
  62. x_ratio = np.clip(med_x / pano_w, 0.0, 1.0)
  63. y_ratio = np.clip(med_y / pano_h, 0.0, 1.0)
  64. spread = float(np.median(np.linalg.norm(mkpts_panorama_orig - np.array([med_x, med_y]), axis=1)))
  65. return x_ratio, y_ratio, len(mkpts_panorama), 0.0, spread
  66. def build_models(device: torch.device):
  67. extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)
  68. matcher = LightGlue(
  69. features='superpoint',
  70. depth_confidence=-1,
  71. width_confidence=-1,
  72. ).eval().to(device)
  73. return extractor, matcher
  74. def main():
  75. parser = argparse.ArgumentParser(description='LightGlue PTZ-panorama matching')
  76. parser.add_argument('--base', type=str,
  77. default='/Users/wenhongquan/Desktop/阿里云同步/项目/dnn/德胜河 AI/dsh/calibration_scan_180_360_z1',
  78. help='扫描结果目录')
  79. parser.add_argument('--pano-max-side', type=int, default=2048,
  80. help='全景图缩放后的长边像素')
  81. parser.add_argument('--ptz-max-side', type=int, default=1280,
  82. help='PTZ 图缩放后的长边像素')
  83. parser.add_argument('--ptz-crop-ratio', type=float, default=0.6,
  84. help='仅使用 PTZ 中心区域进行匹配')
  85. parser.add_argument('--min-matches', type=int, default=20,
  86. help='最少匹配点数才认为有效')
  87. parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'mps', 'cuda'],
  88. help='计算设备')
  89. parser.add_argument('--output-suffix', type=str, default='lightglue',
  90. help='输出文件后缀')
  91. args = parser.parse_args()
  92. base = Path(args.base)
  93. ptz_dir = base / 'ptz_images'
  94. pano_path = base / 'panorama.jpg'
  95. raw_path = base / 'mapping_raw.json'
  96. device = torch.device(args.device)
  97. logger.info(f'Using device: {device}')
  98. logger.info('Loading models...')
  99. extractor, matcher = build_models(device)
  100. logger.info(f'Loading panorama: {pano_path}')
  101. panorama_orig = cv2.imread(str(pano_path))
  102. if panorama_orig is None:
  103. logger.error('Failed to load panorama')
  104. return
  105. pano_h, pano_w = panorama_orig.shape[:2]
  106. panorama_small, pano_scale = resize_max_side(panorama_orig, args.pano_max_side)
  107. logger.info(f'Panorama resized: {panorama_small.shape[1]}x{panorama_small.shape[0]} (scale={pano_scale:.3f})')
  108. logger.info('Extracting panorama features...')
  109. pano_tensor = load_tensor(panorama_small, device)[None]
  110. feats_pano = extractor({'image': pano_tensor})
  111. logger.info(f"Panorama keypoints: {feats_pano['keypoints'].shape[1]}")
  112. logger.info(f'Loading records from {raw_path}')
  113. with open(raw_path, 'r', encoding='utf-8') as f:
  114. raw_data = json.load(f)
  115. records = raw_data['records']
  116. results: List[Dict[str, Any]] = []
  117. summary_points = []
  118. for idx, rec in enumerate(records, 1):
  119. filename = rec['filename']
  120. pan = rec['pan']
  121. tilt = rec['tilt']
  122. ptz_path = ptz_dir / filename
  123. ptz_orig = cv2.imread(str(ptz_path))
  124. if ptz_orig is None:
  125. logger.warning(f'[{idx}/{len(records)}] Cannot read {filename}')
  126. rec_copy = dict(rec)
  127. rec_copy['matched'] = False
  128. results.append(rec_copy)
  129. continue
  130. # 中心裁剪 + 缩放
  131. ptz_crop = center_crop(ptz_orig, args.ptz_crop_ratio)
  132. ptz_offset = ((ptz_orig.shape[1] - ptz_crop.shape[1]) // 2,
  133. (ptz_orig.shape[0] - ptz_crop.shape[0]) // 2)
  134. ptz_small, ptz_scale = resize_max_side(ptz_crop, args.ptz_max_side)
  135. ptz_tensor = load_tensor(ptz_small, device)[None]
  136. feats_ptz = extractor({'image': ptz_tensor})
  137. matches = matcher({
  138. 'image0': feats_ptz,
  139. 'image1': feats_pano,
  140. })
  141. feats_ptz = rbd(feats_ptz)
  142. feats_pano_single = rbd(feats_pano)
  143. matches = rbd(matches)
  144. m = matches['matches']
  145. num_matches = len(m)
  146. rec_copy = dict(rec)
  147. rec_copy['lg_matches'] = num_matches
  148. if num_matches >= args.min_matches:
  149. kpts_ptz = feats_ptz['keypoints'][m[..., 0]].cpu().numpy()
  150. kpts_pano = feats_pano_single['keypoints'][m[..., 1]].cpu().numpy()
  151. med = median_point_from_matches(
  152. kpts_pano, kpts_ptz,
  153. pano_scale, ptz_scale, ptz_offset,
  154. pano_h, pano_w,
  155. )
  156. if med:
  157. x_ratio, y_ratio, n, conf, spread = med
  158. rec_copy['lg_matched'] = True
  159. rec_copy['lg_x_ratio'] = round(x_ratio, 4)
  160. rec_copy['lg_y_ratio'] = round(y_ratio, 4)
  161. rec_copy['lg_panorama_x'] = int(x_ratio * pano_w)
  162. rec_copy['lg_panorama_y'] = int(y_ratio * pano_h)
  163. rec_copy['lg_spread'] = round(spread, 2)
  164. summary_points.append((pan, tilt, int(x_ratio * pano_w), int(y_ratio * pano_h)))
  165. logger.info(
  166. f'[{idx}/{len(records)}] {filename} pan={pan} tilt={tilt:+3d} -> '
  167. f'({x_ratio:.3f}, {y_ratio:.3f}) matches={num_matches} spread={spread:.1f}'
  168. )
  169. else:
  170. rec_copy['lg_matched'] = False
  171. logger.info(f'[{idx}/{len(records)}] {filename} matches={num_matches} but no median')
  172. else:
  173. rec_copy['lg_matched'] = False
  174. logger.info(f'[{idx}/{len(records)}] {filename} matches={num_matches} < {args.min_matches}')
  175. results.append(rec_copy)
  176. # 保存详细结果
  177. out_json = base / f'mapping_{args.output_suffix}.json'
  178. with open(out_json, 'w', encoding='utf-8') as f:
  179. json.dump({
  180. 'records': results,
  181. 'panorama_size': {'width': pano_w, 'height': pano_h},
  182. 'params': vars(args),
  183. }, f, indent=2, ensure_ascii=False)
  184. logger.info(f'Saved detailed mapping: {out_json}')
  185. # 生成 lookup table(仅有效点)
  186. valid = [r for r in results if r.get('lg_matched')]
  187. pan_lookup = sorted([[r['lg_x_ratio'], float(r['pan'])] for r in valid], key=lambda x: x[0])
  188. tilt_lookup = sorted([[r['lg_y_ratio'], float(r['tilt'])] for r in valid], key=lambda x: x[0])
  189. lookup = {
  190. 'created_at': raw_data.get('created_at'),
  191. 'pan_lookup': pan_lookup,
  192. 'tilt_lookup': tilt_lookup,
  193. 'valid_count': len(valid),
  194. }
  195. lookup_path = base / f'lookup_table_{args.output_suffix}.json'
  196. with open(lookup_path, 'w', encoding='utf-8') as f:
  197. json.dump(lookup, f, indent=2, ensure_ascii=False)
  198. logger.info(f'Saved lookup table: {lookup_path} ({len(valid)}/{len(records)} valid)')
  199. # 可视化概览:在全景图上标记每个有效位置
  200. if summary_points:
  201. vis = panorama_orig.copy()
  202. for pan, tilt, cx, cy in summary_points:
  203. cv2.circle(vis, (cx, cy), 12, (0, 0, 255), -1)
  204. label = f'{pan},{tilt}'
  205. cv2.putText(vis, label, (cx + 10, cy), cv2.FONT_HERSHEY_SIMPLEX,
  206. 0.5, (0, 255, 0), 1)
  207. vis_path = base / f'panorama_{args.output_suffix}_matches.jpg'
  208. cv2.imwrite(str(vis_path), vis, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
  209. logger.info(f'Saved overview: {vis_path}')
  210. if __name__ == '__main__':
  211. main()