|
|
@@ -0,0 +1,253 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""
|
|
|
+使用 LightGlue (SuperPoint) 对 PTZ 校准扫描图片与全景图做深度学习特征匹配,
|
|
|
+在本地生成全景->PTZ 映射表。
|
|
|
+"""
|
|
|
+import os
|
|
|
+import json
|
|
|
+import argparse
|
|
|
+import logging
|
|
|
+from pathlib import Path
|
|
|
+from typing import Tuple, Optional, Dict, Any, List
|
|
|
+
|
|
|
+import torch
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from lightglue import LightGlue, SuperPoint, match_pair
|
|
|
+from lightglue.utils import rbd
|
|
|
+
|
|
|
# Configure the root logger once at import time: INFO level, with a
# timestamp and severity in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
def load_tensor(bgr_img: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert an OpenCV BGR image into an RGB float tensor on ``device``.

    The channel order is swapped from BGR to RGB, the axes are rearranged
    from (H, W, C) to (C, H, W), and the values are cast to float and divided
    by 255 (which yields the [0, 1] range for 8-bit input).
    """
    rgb_array = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
    channels_first = torch.from_numpy(rgb_array).permute(2, 0, 1)
    scaled = channels_first.float() / 255.0
    return scaled.to(device)
|
|
|
+
|
|
|
+
|
|
|
def resize_max_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
    """Downscale ``img`` so that its longer side is at most ``max_side`` pixels.

    The aspect ratio is preserved and images are never upscaled.

    Args:
        img: Image array whose first two dimensions are (height, width).
        max_side: Upper bound, in pixels, for the longer side. Must be positive.

    Returns:
        ``(resized, scale)`` where ``scale`` is the factor applied to both
        axes. When the image already fits, a copy of ``img`` and ``1.0`` are
        returned.

    Raises:
        ValueError: If ``max_side`` is not positive or the image has a
            zero-sized height or width.
    """
    if max_side <= 0:
        raise ValueError(f'max_side must be positive, got {max_side}')

    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError('cannot resize an image with zero height or width')

    scale = min(max_side / max(h, w), 1.0)
    if scale == 1.0:
        return img.copy(), 1.0

    # Clamp to 1 px: for very elongated images the short side can round down
    # to 0, and cv2.resize refuses a zero-sized destination.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return resized, scale
|
|
|
+
|
|
|
+
|
|
|
def center_crop(img: np.ndarray, ratio: float) -> np.ndarray:
    """Return the centered region covering ``ratio`` of each image side.

    Args:
        img: Image array whose first two dimensions are (height, width).
        ratio: Fraction of the width and of the height to keep, in (0, 1].

    Returns:
        A slice (view) of ``img`` holding the centered crop. For a non-empty
        image the crop is at least 1x1 pixel.

    Raises:
        ValueError: If ``ratio`` is outside (0, 1]. Larger ratios would make
            the slice start negative and silently select the wrong region.
    """
    if not 0.0 < ratio <= 1.0:
        raise ValueError(f'ratio must be in (0, 1], got {ratio}')

    h, w = img.shape[:2]
    # Keep at least one pixel per axis (but never more than the image has)
    # so that tiny images do not collapse to an empty crop.
    new_w = min(w, max(1, int(w * ratio)))
    new_h = min(h, max(1, int(h * ratio)))
    x0 = (w - new_w) // 2
    y0 = (h - new_h) // 2
    return img[y0:y0 + new_h, x0:x0 + new_w]
|
|
|
+
|
|
|
+
|
|
|
def median_point_from_matches(
    mkpts_panorama: np.ndarray,
    mkpts_ptz: np.ndarray,
    pano_scale: float,
    ptz_scale: float,
    ptz_offset: Tuple[int, int],
    pano_h: int,
    pano_w: int,
) -> Optional[Tuple[float, float, int, float, float]]:
    """Estimate where the PTZ view lands in the panorama from matched points.

    The estimate is the per-axis median of the matched panorama keypoints,
    mapped back to original panorama pixels and normalized by the panorama
    size.

    Args:
        mkpts_panorama: (N, 2) array of matched (x, y) keypoints in the
            resized panorama.
        mkpts_ptz: Matched keypoints in the resized PTZ crop. Accepted for
            interface compatibility; not used by the current estimate.
        pano_scale: Resize factor that was applied to the panorama.
        ptz_scale: Resize factor applied to the PTZ crop (currently unused).
        ptz_offset: (x, y) offset of the PTZ crop (currently unused).
        pano_h: Height of the original panorama in pixels.
        pano_w: Width of the original panorama in pixels.

    Returns:
        ``None`` when there are no matches, otherwise
        ``(x_ratio, y_ratio, num_matches, median_confidence, spread)``:
        the ratios are clipped to [0, 1]; ``median_confidence`` is a
        placeholder that is always ``0.0``; ``spread`` is the median
        Euclidean distance of the matched panorama points from the median
        point, in original panorama pixels.
    """
    if len(mkpts_panorama) == 0:
        return None

    # Undo the panorama resize so all distances are in original pixels.
    pano_pts = mkpts_panorama / pano_scale

    med_x = float(np.median(pano_pts[:, 0]))
    med_y = float(np.median(pano_pts[:, 1]))

    # Cast to built-in float so callers (and JSON output) get plain numbers
    # rather than NumPy scalars.
    x_ratio = float(np.clip(med_x / pano_w, 0.0, 1.0))
    y_ratio = float(np.clip(med_y / pano_h, 0.0, 1.0))

    distances = np.linalg.norm(pano_pts - np.array([med_x, med_y]), axis=1)
    spread = float(np.median(distances))

    return x_ratio, y_ratio, len(mkpts_panorama), 0.0, spread
|
|
|
+
|
|
|
+
|
|
|
def build_models(device: torch.device):
    """Create the SuperPoint extractor and LightGlue matcher on ``device``.

    Both modules are switched to eval mode before being moved to the device.
    Returns the pair ``(extractor, matcher)``.
    """
    superpoint = SuperPoint(max_num_keypoints=2048)
    # NOTE(review): -1 for both confidences presumably disables LightGlue's
    # adaptive early-exit / pruning; confirm against the LightGlue README.
    lightglue = LightGlue(
        features='superpoint',
        depth_confidence=-1,
        width_confidence=-1,
    )
    extractor = superpoint.eval().to(device)
    matcher = lightglue.eval().to(device)
    return extractor, matcher
|
|
|
+
|
|
|
+
|
|
|
def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options of the matching script."""
    parser = argparse.ArgumentParser(description='LightGlue PTZ-panorama matching')
    # NOTE(review): the default below is a machine-specific absolute path;
    # callers on other machines must pass --base explicitly.
    parser.add_argument('--base', type=str,
                        default='/Users/wenhongquan/Desktop/阿里云同步/项目/dnn/德胜河 AI/dsh/calibration_scan_180_360_z1',
                        help='扫描结果目录')
    parser.add_argument('--pano-max-side', type=int, default=2048,
                        help='全景图缩放后的长边像素')
    parser.add_argument('--ptz-max-side', type=int, default=1280,
                        help='PTZ 图缩放后的长边像素')
    parser.add_argument('--ptz-crop-ratio', type=float, default=0.6,
                        help='仅使用 PTZ 中心区域进行匹配')
    parser.add_argument('--min-matches', type=int, default=20,
                        help='最少匹配点数才认为有效')
    parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'mps', 'cuda'],
                        help='计算设备')
    parser.add_argument('--output-suffix', type=str, default='lightglue',
                        help='输出文件后缀')
    args = parser.parse_args()

    # Reject values that would otherwise produce wrong crops or fail deep
    # inside OpenCV with an unhelpful error.
    if not 0.0 < args.ptz_crop_ratio <= 1.0:
        parser.error('--ptz-crop-ratio must be in (0, 1]')
    if args.pano_max_side <= 0 or args.ptz_max_side <= 0:
        parser.error('--pano-max-side and --ptz-max-side must be positive')
    return args


def _format_tilt(tilt: Any) -> str:
    """Format a tilt value for the log line.

    Integers keep the signed, width-3 layout; floats (which JSON may contain)
    get one decimal instead of crashing on the ``d`` format code.
    """
    if isinstance(tilt, int):
        return f'{tilt:+3d}'
    if isinstance(tilt, float):
        return f'{tilt:+.1f}'
    return str(tilt)


def main():
    """Match every PTZ scan image against the panorama and write the results.

    Reads ``panorama.jpg``, ``mapping_raw.json`` and ``ptz_images/`` from the
    ``--base`` directory, and writes into the same directory:

    * ``mapping_<suffix>.json``: every input record plus ``lg_*`` fields.
    * ``lookup_table_<suffix>.json``: sorted (ratio, angle) pairs built from
      the records that matched.
    * ``panorama_<suffix>_matches.jpg``: the panorama with each matched
      position marked (only when at least one record matched).

    Returns early (after logging an error) if the panorama cannot be read.
    """
    args = _parse_args()

    base = Path(args.base)
    ptz_dir = base / 'ptz_images'
    pano_path = base / 'panorama.jpg'
    raw_path = base / 'mapping_raw.json'

    device = torch.device(args.device)
    logger.info(f'Using device: {device}')

    logger.info('Loading models...')
    extractor, matcher = build_models(device)

    logger.info(f'Loading panorama: {pano_path}')
    panorama_orig = cv2.imread(str(pano_path))
    if panorama_orig is None:
        logger.error('Failed to load panorama')
        return
    pano_h, pano_w = panorama_orig.shape[:2]
    panorama_small, pano_scale = resize_max_side(panorama_orig, args.pano_max_side)
    logger.info(f'Panorama resized: {panorama_small.shape[1]}x{panorama_small.shape[0]} (scale={pano_scale:.3f})')

    logger.info('Extracting panorama features...')
    pano_tensor = load_tensor(panorama_small, device)[None]
    # Inference only: disable autograd so no graph is kept for the features.
    with torch.no_grad():
        feats_pano = extractor({'image': pano_tensor})
    logger.info(f"Panorama keypoints: {feats_pano['keypoints'].shape[1]}")
    # The un-batched panorama features do not change per record; build once.
    feats_pano_single = rbd(feats_pano)

    logger.info(f'Loading records from {raw_path}')
    with open(raw_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    records = raw_data['records']
    total = len(records)

    results: List[Dict[str, Any]] = []
    summary_points = []

    for idx, rec in enumerate(records, 1):
        filename = rec['filename']
        pan = rec['pan']
        tilt = rec['tilt']
        rec_copy = dict(rec)

        ptz_orig = cv2.imread(str(ptz_dir / filename))
        if ptz_orig is None:
            logger.warning(f'[{idx}/{total}] Cannot read {filename}')
            # Use the same 'lg_matched' key as every other branch so the
            # record's own fields are not overwritten.
            rec_copy['lg_matched'] = False
            results.append(rec_copy)
            continue

        # Center crop, then downscale; remember the crop offset in pixels.
        ptz_crop = center_crop(ptz_orig, args.ptz_crop_ratio)
        ptz_offset = ((ptz_orig.shape[1] - ptz_crop.shape[1]) // 2,
                      (ptz_orig.shape[0] - ptz_crop.shape[0]) // 2)
        ptz_small, ptz_scale = resize_max_side(ptz_crop, args.ptz_max_side)
        ptz_tensor = load_tensor(ptz_small, device)[None]

        with torch.no_grad():
            feats_ptz = extractor({'image': ptz_tensor})
            matches = matcher({
                'image0': feats_ptz,
                'image1': feats_pano,
            })

        feats_ptz = rbd(feats_ptz)
        m = rbd(matches)['matches']
        num_matches = len(m)
        rec_copy['lg_matches'] = num_matches

        if num_matches >= args.min_matches:
            kpts_ptz = feats_ptz['keypoints'][m[..., 0]].cpu().numpy()
            kpts_pano = feats_pano_single['keypoints'][m[..., 1]].cpu().numpy()

            med = median_point_from_matches(
                kpts_pano, kpts_ptz,
                pano_scale, ptz_scale, ptz_offset,
                pano_h, pano_w,
            )
            if med is not None:
                x_ratio, y_ratio, _, _, spread = med
                pano_x = int(x_ratio * pano_w)
                pano_y = int(y_ratio * pano_h)
                rec_copy['lg_matched'] = True
                rec_copy['lg_x_ratio'] = round(x_ratio, 4)
                rec_copy['lg_y_ratio'] = round(y_ratio, 4)
                rec_copy['lg_panorama_x'] = pano_x
                rec_copy['lg_panorama_y'] = pano_y
                rec_copy['lg_spread'] = round(spread, 2)
                summary_points.append((pan, tilt, pano_x, pano_y))
                logger.info(
                    f'[{idx}/{total}] {filename} pan={pan} tilt={_format_tilt(tilt)} -> '
                    f'({x_ratio:.3f}, {y_ratio:.3f}) matches={num_matches} spread={spread:.1f}'
                )
            else:
                rec_copy['lg_matched'] = False
                logger.info(f'[{idx}/{total}] {filename} matches={num_matches} but no median')
        else:
            rec_copy['lg_matched'] = False
            logger.info(f'[{idx}/{total}] {filename} matches={num_matches} < {args.min_matches}')

        results.append(rec_copy)

    # Save the detailed per-record results.
    out_json = base / f'mapping_{args.output_suffix}.json'
    with open(out_json, 'w', encoding='utf-8') as f:
        json.dump({
            'records': results,
            'panorama_size': {'width': pano_w, 'height': pano_h},
            'params': vars(args),
        }, f, indent=2, ensure_ascii=False)
    logger.info(f'Saved detailed mapping: {out_json}')

    # Build the lookup table from matched records only.
    valid = [r for r in results if r.get('lg_matched')]
    pan_lookup = sorted([[r['lg_x_ratio'], float(r['pan'])] for r in valid], key=lambda x: x[0])
    tilt_lookup = sorted([[r['lg_y_ratio'], float(r['tilt'])] for r in valid], key=lambda x: x[0])
    lookup = {
        'created_at': raw_data.get('created_at'),
        'pan_lookup': pan_lookup,
        'tilt_lookup': tilt_lookup,
        'valid_count': len(valid),
    }
    lookup_path = base / f'lookup_table_{args.output_suffix}.json'
    with open(lookup_path, 'w', encoding='utf-8') as f:
        json.dump(lookup, f, indent=2, ensure_ascii=False)
    logger.info(f'Saved lookup table: {lookup_path} ({len(valid)}/{total} valid)')

    # Overview image: mark every matched position on the full panorama.
    if summary_points:
        vis = panorama_orig.copy()
        for pan, tilt, cx, cy in summary_points:
            cv2.circle(vis, (cx, cy), 12, (0, 0, 255), -1)
            label = f'{pan},{tilt}'
            cv2.putText(vis, label, (cx + 10, cy), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 255, 0), 1)
        vis_path = base / f'panorama_{args.output_suffix}_matches.jpg'
        cv2.imwrite(str(vis_path), vis, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
        logger.info(f'Saved overview: {vis_path}')
|
|
|
+
|
|
|
+
|
|
|
# Run the matching pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|