| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- #!/usr/bin/env python3
- """
- 使用 LightGlue (SuperPoint) 对 PTZ 校准扫描图片与全景图做深度学习特征匹配,
- 在本地生成全景->PTZ 映射表。
- """
- import os
- import json
- import argparse
- import logging
- from pathlib import Path
- from typing import Tuple, Optional, Dict, Any, List
- import torch
- import cv2
- import numpy as np
- from lightglue import LightGlue, SuperPoint, match_pair
- from lightglue.utils import rbd
# Root logging configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def load_tensor(bgr_img: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert a BGR image array into a channels-first RGB float tensor.

    Args:
        bgr_img: Image in BGR channel order (it is passed to
            ``cv2.cvtColor`` with ``COLOR_BGR2RGB``).
        device: Torch device the resulting tensor is moved to.

    Returns:
        Float tensor laid out as (C, H, W) whose values are the pixel values
        divided by 255.
    """
    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
    # (H, W, C) -> (C, H, W)
    channels_first = torch.from_numpy(rgb_img).permute(2, 0, 1)
    scaled = channels_first.float() / 255.0
    return scaled.to(device)
def resize_max_side(img: np.ndarray, max_side: int) -> Tuple[np.ndarray, float]:
    """Downscale ``img`` so its longer side is at most ``max_side`` pixels.

    The aspect ratio is kept. Images that already fit are never upscaled;
    a copy is returned instead.

    Args:
        img: Image array whose first two dimensions are (height, width).
        max_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        ``(resized, scale)`` where ``scale`` is the nominal factor applied to
        both axes (``1.0`` when no resizing happened). Because the target
        size is rounded to whole pixels, the effective per-axis factor can
        differ marginally from ``scale``.

    Raises:
        ValueError: If ``max_side`` is not positive or the image is empty.
    """
    if max_side <= 0:
        raise ValueError(f'max_side must be positive, got {max_side}')
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest == 0:
        raise ValueError('cannot resize an empty image')
    if longest <= max_side:
        return img.copy(), 1.0
    scale = max_side / longest
    # Clamp to 1 px: for extremely thin images the short side would
    # otherwise round to 0, which cv2.resize rejects.
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA), scale
def center_crop(img: np.ndarray, ratio: float) -> np.ndarray:
    """Return the centred region covering ``ratio`` of each image dimension.

    Args:
        img: Image array whose first two dimensions are (height, width).
        ratio: Fraction of the width and of the height to keep; must satisfy
            ``0 < ratio <= 1``.

    Returns:
        A slice (view) of ``img``; the crop size is ``int(dim * ratio)`` per
        axis, clamped to at least one pixel and at most the full dimension.

    Raises:
        ValueError: If ``ratio`` is outside ``(0, 1]``.
    """
    # A ratio above 1 would produce negative offsets, which NumPy interprets
    # as indices from the end and silently returns the wrong region.
    if not 0.0 < ratio <= 1.0:
        raise ValueError(f'ratio must be in (0, 1], got {ratio}')
    h, w = img.shape[:2]
    new_w = min(w, max(1, int(w * ratio)))
    new_h = min(h, max(1, int(h * ratio)))
    x0, y0 = (w - new_w) // 2, (h - new_h) // 2
    return img[y0:y0 + new_h, x0:x0 + new_w]
def median_point_from_matches(
    mkpts_panorama: np.ndarray,
    mkpts_ptz: np.ndarray,
    pano_scale: float,
    ptz_scale: float,
    ptz_offset: Tuple[int, int],
    pano_h: int,
    pano_w: int,
) -> Optional[Tuple[float, float, int, float, float]]:
    """Estimate where the PTZ view lands in the panorama from matched keypoints.

    The estimate is the per-axis median of the matched panorama keypoints,
    mapped back to original panorama pixels and normalised by the panorama
    size.

    Args:
        mkpts_panorama: (N, 2) matched keypoints (x, y) in the *resized*
            panorama.
        mkpts_ptz: Matched keypoints in the PTZ image. Currently unused; kept
            so existing callers keep working.
        pano_scale: Factor that was applied when resizing the panorama;
            coordinates are divided by it to get original pixels.
        ptz_scale: Resize factor of the PTZ crop. Currently unused.
        ptz_offset: (x, y) offset of the PTZ crop. Currently unused.
        pano_h: Height of the original panorama in pixels.
        pano_w: Width of the original panorama in pixels.

    Returns:
        ``None`` when there are no matches, otherwise
        ``(x_ratio, y_ratio, num_matches, median_confidence, spread)``:
        the ratios are clipped to ``[0, 1]``; ``median_confidence`` is a
        placeholder that is always ``0.0`` (no confidences are passed in);
        ``spread`` is the median Euclidean distance, in original panorama
        pixels, of the matched points from the median point.

    Raises:
        ValueError: If ``pano_scale``, ``pano_w`` or ``pano_h`` is not
            positive.
    """
    num_matches = len(mkpts_panorama)
    if num_matches == 0:
        return None
    if pano_scale <= 0 or pano_w <= 0 or pano_h <= 0:
        raise ValueError('pano_scale, pano_w and pano_h must be positive')

    # Undo the resize so everything below is in original panorama pixels.
    pano_pts = mkpts_panorama / pano_scale

    med_x = float(np.median(pano_pts[:, 0]))
    med_y = float(np.median(pano_pts[:, 1]))

    # np.clip returns a NumPy scalar; convert so callers get built-in floats.
    x_ratio = float(np.clip(med_x / pano_w, 0.0, 1.0))
    y_ratio = float(np.clip(med_y / pano_h, 0.0, 1.0))

    distances = np.linalg.norm(pano_pts - np.array([med_x, med_y]), axis=1)
    spread = float(np.median(distances))
    return x_ratio, y_ratio, num_matches, 0.0, spread
def build_models(device: torch.device):
    """Create the SuperPoint extractor and the LightGlue matcher on ``device``.

    Both modules are switched to eval mode before being moved to the device.

    Args:
        device: Torch device that will host both models.

    Returns:
        ``(extractor, matcher)``: a SuperPoint instance capped at 2048
        keypoints and a LightGlue instance configured for SuperPoint
        features.
    """
    superpoint = SuperPoint(max_num_keypoints=2048)
    extractor = superpoint.eval().to(device)

    # NOTE(review): per the LightGlue README, -1 disables the adaptive
    # depth (early stopping) and width (point pruning) mechanisms.
    matcher_options = {
        'features': 'superpoint',
        'depth_confidence': -1,
        'width_confidence': -1,
    }
    matcher = LightGlue(**matcher_options).eval().to(device)

    return extractor, matcher
def main():
    """Match every PTZ calibration image against the panorama and save results.

    Reads, from the ``--base`` directory: ``panorama.jpg``,
    ``mapping_raw.json`` (its ``records`` list supplies ``filename``, ``pan``
    and ``tilt`` per image) and the images under ``ptz_images/``.

    Writes, into the same directory (``<suffix>`` is ``--output-suffix``):

    * ``mapping_<suffix>.json`` - every input record plus ``lg_*`` fields,
    * ``lookup_table_<suffix>.json`` - ratio -> pan / tilt pairs, sorted by
      ratio, for successfully matched records only,
    * ``panorama_<suffix>_matches.jpg`` - the panorama with each matched
      position marked (only when at least one record matched).

    Returns early, after logging an error, if the panorama cannot be read.
    """
    parser = argparse.ArgumentParser(description='LightGlue PTZ-panorama matching')
    parser.add_argument('--base', type=str,
                        default='/Users/wenhongquan/Desktop/阿里云同步/项目/dnn/德胜河 AI/dsh/calibration_scan_180_360_z1',
                        help='扫描结果目录')
    parser.add_argument('--pano-max-side', type=int, default=2048,
                        help='全景图缩放后的长边像素')
    parser.add_argument('--ptz-max-side', type=int, default=1280,
                        help='PTZ 图缩放后的长边像素')
    parser.add_argument('--ptz-crop-ratio', type=float, default=0.6,
                        help='仅使用 PTZ 中心区域进行匹配')
    parser.add_argument('--min-matches', type=int, default=20,
                        help='最少匹配点数才认为有效')
    parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'mps', 'cuda'],
                        help='计算设备')
    parser.add_argument('--output-suffix', type=str, default='lightglue',
                        help='输出文件后缀')
    args = parser.parse_args()

    # Reject values that would otherwise fail (or silently misbehave) deep
    # inside the cropping / resizing helpers.
    if not 0.0 < args.ptz_crop_ratio <= 1.0:
        parser.error('--ptz-crop-ratio must be in (0, 1]')
    if args.pano_max_side <= 0 or args.ptz_max_side <= 0:
        parser.error('--pano-max-side and --ptz-max-side must be positive')

    base = Path(args.base)
    ptz_dir = base / 'ptz_images'
    pano_path = base / 'panorama.jpg'
    raw_path = base / 'mapping_raw.json'

    device = torch.device(args.device)
    logger.info(f'Using device: {device}')

    logger.info('Loading models...')
    extractor, matcher = build_models(device)

    logger.info(f'Loading panorama: {pano_path}')
    panorama_orig = cv2.imread(str(pano_path))
    if panorama_orig is None:
        logger.error('Failed to load panorama')
        return
    pano_h, pano_w = panorama_orig.shape[:2]
    panorama_small, pano_scale = resize_max_side(panorama_orig, args.pano_max_side)
    logger.info(f'Panorama resized: {panorama_small.shape[1]}x{panorama_small.shape[0]} (scale={pano_scale:.3f})')

    logger.info('Extracting panorama features...')
    pano_tensor = load_tensor(panorama_small, device)[None]
    # Inference only: skip autograd bookkeeping so no graph is kept alive
    # for the panorama features across the whole loop.
    with torch.no_grad():
        feats_pano = extractor({'image': pano_tensor})
    # The un-batched panorama features do not change between records.
    feats_pano_single = rbd(feats_pano)
    logger.info(f"Panorama keypoints: {feats_pano['keypoints'].shape[1]}")

    logger.info(f'Loading records from {raw_path}')
    with open(raw_path, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)
    records = raw_data['records']
    total = len(records)

    results: List[Dict[str, Any]] = []
    summary_points = []

    for idx, rec in enumerate(records, 1):
        filename = rec['filename']
        pan = rec['pan']
        tilt = rec['tilt']
        rec_copy = dict(rec)

        ptz_orig = cv2.imread(str(ptz_dir / filename))
        if ptz_orig is None:
            logger.warning(f'[{idx}/{total}] Cannot read {filename}')
            rec_copy['matched'] = False
            # Also set the key every other branch (and the lookup filter
            # below) uses, so all output records share one schema.
            rec_copy['lg_matched'] = False
            results.append(rec_copy)
            continue

        # Centre crop, then downscale.
        ptz_crop = center_crop(ptz_orig, args.ptz_crop_ratio)
        ptz_offset = ((ptz_orig.shape[1] - ptz_crop.shape[1]) // 2,
                      (ptz_orig.shape[0] - ptz_crop.shape[0]) // 2)
        ptz_small, ptz_scale = resize_max_side(ptz_crop, args.ptz_max_side)

        ptz_tensor = load_tensor(ptz_small, device)[None]
        with torch.no_grad():
            feats_ptz = extractor({'image': ptz_tensor})
            matches = matcher({
                'image0': feats_ptz,
                'image1': feats_pano,
            })
        feats_ptz = rbd(feats_ptz)
        # Column 0 indexes PTZ keypoints, column 1 panorama keypoints.
        m = rbd(matches)['matches']
        num_matches = len(m)
        rec_copy['lg_matches'] = num_matches

        if num_matches < args.min_matches:
            rec_copy['lg_matched'] = False
            logger.info(f'[{idx}/{total}] {filename} matches={num_matches} < {args.min_matches}')
            results.append(rec_copy)
            continue

        kpts_ptz = feats_ptz['keypoints'][m[..., 0]].cpu().numpy()
        kpts_pano = feats_pano_single['keypoints'][m[..., 1]].cpu().numpy()
        med = median_point_from_matches(
            kpts_pano, kpts_ptz,
            pano_scale, ptz_scale, ptz_offset,
            pano_h, pano_w,
        )
        if med is None:
            rec_copy['lg_matched'] = False
            logger.info(f'[{idx}/{total}] {filename} matches={num_matches} but no median')
            results.append(rec_copy)
            continue

        x_ratio, y_ratio, _n, _conf, spread = med
        # Built-in floats keep the JSON output independent of NumPy scalars.
        x_ratio, y_ratio = float(x_ratio), float(y_ratio)
        pano_x = int(x_ratio * pano_w)
        pano_y = int(y_ratio * pano_h)
        rec_copy['lg_matched'] = True
        rec_copy['lg_x_ratio'] = round(x_ratio, 4)
        rec_copy['lg_y_ratio'] = round(y_ratio, 4)
        rec_copy['lg_panorama_x'] = pano_x
        rec_copy['lg_panorama_y'] = pano_y
        rec_copy['lg_spread'] = round(spread, 2)
        summary_points.append((pan, tilt, pano_x, pano_y))
        # '+3g' prints integers exactly like '+3d' but does not raise when
        # the JSON stores tilt as a float.
        logger.info(
            f'[{idx}/{total}] {filename} pan={pan} tilt={tilt:+3g} -> '
            f'({x_ratio:.3f}, {y_ratio:.3f}) matches={num_matches} spread={spread:.1f}'
        )
        results.append(rec_copy)

    # Save the detailed per-record results.
    out_json = base / f'mapping_{args.output_suffix}.json'
    with open(out_json, 'w', encoding='utf-8') as f:
        json.dump({
            'records': results,
            'panorama_size': {'width': pano_w, 'height': pano_h},
            'params': vars(args),
        }, f, indent=2, ensure_ascii=False)
    logger.info(f'Saved detailed mapping: {out_json}')

    # Build the lookup table from successfully matched records only.
    valid = [r for r in results if r.get('lg_matched')]
    pan_lookup = sorted([[r['lg_x_ratio'], float(r['pan'])] for r in valid], key=lambda x: x[0])
    tilt_lookup = sorted([[r['lg_y_ratio'], float(r['tilt'])] for r in valid], key=lambda x: x[0])
    lookup = {
        'created_at': raw_data.get('created_at'),
        'pan_lookup': pan_lookup,
        'tilt_lookup': tilt_lookup,
        'valid_count': len(valid),
    }
    lookup_path = base / f'lookup_table_{args.output_suffix}.json'
    with open(lookup_path, 'w', encoding='utf-8') as f:
        json.dump(lookup, f, indent=2, ensure_ascii=False)
    logger.info(f'Saved lookup table: {lookup_path} ({len(valid)}/{total} valid)')

    # Overview image: mark every matched position on the panorama.
    if summary_points:
        vis = panorama_orig.copy()
        for pan, tilt, cx, cy in summary_points:
            cv2.circle(vis, (cx, cy), 12, (0, 0, 255), -1)
            label = f'{pan},{tilt}'
            cv2.putText(vis, label, (cx + 10, cy), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 255, 0), 1)
        vis_path = base / f'panorama_{args.output_suffix}_matches.jpg'
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(str(vis_path), vis, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
            logger.info(f'Saved overview: {vis_path}')
        else:
            logger.warning(f'Failed to save overview: {vis_path}')


if __name__ == '__main__':
    main()
|