Преглед изворни кода

feat: add SIFT/ORB feature matcher

wenhongquan пре 1 недеља
родитељ
комит
6e9bad84fc
2 измењених фајлова са 100 додато и 1 уклоњено
  1. 79 0
      calibration_scan_180_360/matchers.py
  2. 21 1
      calibration_scan_180_360/test_matchers.py

+ 79 - 0
calibration_scan_180_360/matchers.py

@@ -50,3 +50,82 @@ class TemplateMatcher:
         if best is None or best[2] < self.score_threshold:
         if best is None or best[2] < self.score_threshold:
             return None
             return None
         return best
         return best
+
+
+class FeatureMatcher:
+    def __init__(self,
+                 lowe_ratio: float = 0.75,
+                 min_matches: int = 10,
+                 min_inliers: int = 5,
+                 ransac_threshold: float = 4.0):
+        self.lowe_ratio = lowe_ratio
+        self.min_matches = min_matches
+        self.min_inliers = min_inliers
+        self.ransac_threshold = ransac_threshold
+
+        try:
+            self.detector = cv2.SIFT_create()
+            self.norm = cv2.NORM_L2
+            self.feature_type = 'SIFT'
+        except AttributeError:
+            self.detector = cv2.ORB_create(nfeatures=500)
+            self.norm = cv2.NORM_HAMMING
+            self.feature_type = 'ORB'
+
+        self.matcher = cv2.BFMatcher(self.norm)
+
+    def match(self, ptz_img: np.ndarray, panorama_img: np.ndarray
+              ) -> Optional[Tuple[float, float, int, int]]:
+        if ptz_img is None or panorama_img is None:
+            return None
+
+        ptz_gray = cv2.cvtColor(ptz_img, cv2.COLOR_BGR2GRAY) if len(ptz_img.shape) == 3 else ptz_img
+        pano_gray = cv2.cvtColor(panorama_img, cv2.COLOR_BGR2GRAY) if len(panorama_img.shape) == 3 else panorama_img
+        pano_h, pano_w = pano_gray.shape
+
+        # 缩小加速,但坐标按比例还原
+        ptz_scale = 1.0
+        pano_scale = 1.0
+        max_dim = 960
+        if ptz_gray.shape[1] > max_dim:
+            ptz_scale = max_dim / ptz_gray.shape[1]
+            ptz_gray = cv2.resize(ptz_gray, None, fx=ptz_scale, fy=ptz_scale, interpolation=cv2.INTER_AREA)
+        if pano_gray.shape[1] > max_dim:
+            pano_scale = max_dim / pano_gray.shape[1]
+            pano_gray = cv2.resize(pano_gray, None, fx=pano_scale, fy=pano_scale, interpolation=cv2.INTER_AREA)
+
+        kp1, des1 = self.detector.detectAndCompute(ptz_gray, None)
+        kp2, des2 = self.detector.detectAndCompute(pano_gray, None)
+
+        if des1 is None or des2 is None or len(kp1) < 4 or len(kp2) < 4:
+            return None
+
+        raw_matches = self.matcher.knnMatch(des1, des2, k=2)
+        good = []
+        for pair in raw_matches:
+            if len(pair) == 2:
+                m, n = pair
+                if m.distance < self.lowe_ratio * n.distance:
+                    good.append(m)
+
+        if len(good) < self.min_matches:
+            return None
+
+        ptz_pts = np.float32([kp1[m.queryIdx].pt for m in good])
+        pano_pts = np.float32([kp2[m.trainIdx].pt for m in good])
+
+        try:
+            _, mask = cv2.findHomography(ptz_pts, pano_pts, cv2.RANSAC, self.ransac_threshold)
+            inlier_mask = mask.ravel().astype(bool)
+            inlier_count = int(np.sum(inlier_mask))
+        except Exception:
+            return None
+
+        if inlier_count < self.min_inliers:
+            return None
+
+        inlier_pano_pts = pano_pts[inlier_mask]
+        center_x = np.mean(inlier_pano_pts[:, 0]) / pano_scale
+        center_y = np.mean(inlier_pano_pts[:, 1]) / pano_scale
+
+        return (center_x / pano_w, center_y / pano_h, inlier_count, len(good))

+ 21 - 1
calibration_scan_180_360/test_matchers.py

@@ -1,5 +1,5 @@
 import numpy as np
 import numpy as np
-from matchers import TemplateMatcher
+from matchers import TemplateMatcher, FeatureMatcher
 
 
 
 
 def test_template_matcher_finds_known_location():
 def test_template_matcher_finds_known_location():
@@ -24,3 +24,23 @@ def test_template_matcher_finds_known_location():
     assert 0.45 <= x_ratio <= 0.55
     assert 0.45 <= x_ratio <= 0.55
     assert 0.45 <= y_ratio <= 0.55
     assert 0.45 <= y_ratio <= 0.55
     assert score > 0.8
     assert score > 0.8
+
+
+def test_feature_matcher_finds_known_location():
+    # 用真实纹理更丰富的图案,避免弱纹理下 ORB 关键点不足
+    np.random.seed(1)
+    panorama = np.random.randint(0, 255, (200, 400, 3), dtype=np.uint8)
+    # 在中间嵌入一块带纹理的 PTZ 区域
+    patch = np.random.randint(0, 255, (40, 40, 3), dtype=np.uint8)
+    panorama[80:120, 180:220] = patch
+
+    ptz = panorama[80:120, 180:220].copy()
+
+    matcher = FeatureMatcher(lowe_ratio=0.8, min_matches=4, min_inliers=4)
+    result = matcher.match(ptz, panorama)
+
+    assert result is not None
+    x_ratio, y_ratio, inliers, matches = result
+    assert 0.45 <= x_ratio <= 0.55
+    assert 0.45 <= y_ratio <= 0.55
+    assert inliers >= 4