Эх сурвалжийг харах

feat: add mapping model with fusion, RANSAC and lookup tables

wenhongquan 1 долоо хоног өмнө
parent
commit
f69b9cf44d

+ 348 - 0
calibration_scan_180_360/mapping_model.py

@@ -0,0 +1,348 @@
+from typing import Dict, List, Optional, Tuple
+import math
+import numpy as np
+
+
+class MappingModel:
+    def __init__(self):
+        self.pan_offset = 0.0
+        self.pan_scale_x = 0.0
+        self.pan_scale_y = 0.0
+        self.tilt_offset = 0.0
+        self.tilt_scale_x = 0.0
+        self.tilt_scale_y = 0.0
+        self.rms_error = 0.0
+        self.pan_lookup: List[Tuple[float, float]] = []
+        self.tilt_lookup: List[Tuple[float, float]] = []
+        self.panorama_width = 3840
+        self.panorama_height = 1080
+
+    @staticmethod
+    def _angular_diff(a: float, b: float) -> float:
+        diff = a - b
+        while diff > 180:
+            diff -= 360
+        while diff < -180:
+            diff += 360
+        return diff
+
+    @staticmethod
+    def _unwrap_pan_angles(pan_values: np.ndarray) -> np.ndarray:
+        if len(pan_values) == 0:
+            return pan_values
+        ref = float(np.median(pan_values))
+        unwrapped = pan_values.astype(float).copy()
+        for i in range(len(unwrapped)):
+            diff = unwrapped[i] - ref
+            while diff > 180:
+                unwrapped[i] -= 360
+                diff = unwrapped[i] - ref
+            while diff < -180:
+                unwrapped[i] += 360
+                diff = unwrapped[i] - ref
+        return unwrapped
+
+    def fit(self, records: List[Dict]) -> bool:
+        valid = [r for r in records if 'x_ratio' in r and 'y_ratio' in r]
+        if len(valid) < 4:
+            return False
+
+        # RANSAC 过滤
+        inlier_mask = self._ransac_filter(valid)
+        inliers = [valid[i] for i in range(len(valid)) if inlier_mask[i]]
+        if len(inliers) < 4:
+            inliers = valid
+
+        # 最小二乘拟合
+        self._fit_linear(inliers)
+
+        # 生成查找表
+        self._build_lookups(inliers)
+
+        # 计算 RMS
+        self.rms_error = self._calculate_rms_error(inliers)
+        return True
+
+    def _ransac_filter(self, records: List[Dict],
+                       iterations: int = 200,
+                       threshold: float = 15.0) -> np.ndarray:
+        n = len(records)
+        if n < 8:
+            return np.ones(n, dtype=bool)
+
+        x = np.array([r['x_ratio'] for r in records])
+        y = np.array([r['y_ratio'] for r in records])
+        pan = self._unwrap_pan_angles(np.array([r['pan'] for r in records]))
+        tilt = np.array([r['tilt'] for r in records])
+
+        best_inliers = np.ones(n, dtype=bool)
+        best_count = 0
+        rng = np.random.RandomState(42)
+
+        for _ in range(iterations):
+            idx = rng.choice(n, 4, replace=False)
+            A = np.ones((4, 3))
+            A[:, 1] = x[idx]
+            A[:, 2] = y[idx]
+            try:
+                pan_params, _, _, _ = np.linalg.lstsq(A, pan[idx], rcond=None)
+                tilt_params, _, _, _ = np.linalg.lstsq(A, tilt[idx], rcond=None)
+            except np.linalg.LinAlgError:
+                continue
+
+            pred_pan = pan_params[0] + pan_params[1] * x + pan_params[2] * y
+            pred_tilt = tilt_params[0] + tilt_params[1] * x + tilt_params[2] * y
+            pan_err = np.array([self._angular_diff(float(pred_pan[i]), float(pan[i])) for i in range(n)])
+            tilt_err = pred_tilt - tilt
+            errors = np.sqrt(pan_err ** 2 + tilt_err ** 2)
+            inliers = errors < threshold
+            count = int(np.sum(inliers))
+            if count > best_count:
+                best_count = count
+                best_inliers = inliers
+
+        return best_inliers
+
+    def _fit_linear(self, records: List[Dict]) -> None:
+        x = np.array([r['x_ratio'] for r in records])
+        y = np.array([r['y_ratio'] for r in records])
+        pan = self._unwrap_pan_angles(np.array([r['pan'] for r in records]))
+        tilt = np.array([r['tilt'] for r in records])
+
+        A = np.ones((len(records), 3))
+        A[:, 1] = x
+        A[:, 2] = y
+        pan_params, _, _, _ = np.linalg.lstsq(A, pan, rcond=None)
+        tilt_params, _, _, _ = np.linalg.lstsq(A, tilt, rcond=None)
+
+        self.pan_offset = float(pan_params[0])
+        self.pan_scale_x = float(pan_params[1])
+        self.pan_scale_y = float(pan_params[2])
+        self.tilt_offset = float(tilt_params[0])
+        self.tilt_scale_x = float(tilt_params[1])
+        self.tilt_scale_y = float(tilt_params[2])
+
+        # 系数异常则回退简化模型
+        if (abs(self.pan_scale_x) > 500 or abs(self.pan_scale_y) > 500 or
+                abs(self.tilt_scale_x) > 300 or abs(self.tilt_scale_y) > 300):
+            A_pan = np.ones((len(records), 2))
+            A_pan[:, 1] = x
+            A_tilt = np.ones((len(records), 2))
+            A_tilt[:, 1] = y
+            pan_params_s, _, _, _ = np.linalg.lstsq(A_pan, pan, rcond=None)
+            tilt_params_s, _, _, _ = np.linalg.lstsq(A_tilt, tilt, rcond=None)
+            self.pan_offset = float(pan_params_s[0])
+            self.pan_scale_x = float(pan_params_s[1])
+            self.pan_scale_y = 0.0
+            self.tilt_offset = float(tilt_params_s[0])
+            self.tilt_scale_x = 0.0
+            self.tilt_scale_y = float(tilt_params_s[1])
+
+    def _build_lookups(self, records: List[Dict]) -> None:
+        grid_size = 0.05
+
+        # pan_lookup
+        x_buckets: Dict[float, List[Tuple[float, float]]] = {}
+        for r in records:
+            x_key = round(r['x_ratio'] / grid_size) * grid_size
+            x_buckets.setdefault(x_key, []).append((r['pan'], 1.0))
+
+        raw = []
+        for x_key in sorted(x_buckets.keys()):
+            pans = [p for p, _ in x_buckets[x_key]]
+            raw.append((x_key, float(np.median(pans)), len(pans)))
+
+        filtered = self._filter_continuous_monotonic(raw)
+        self.pan_lookup = [(x, pan) for x, pan, _ in filtered]
+
+        # tilt_lookup:只在 pan_lookup 有效的 x 附近取点
+        pan_valid_x = {x for x, _ in self.pan_lookup}
+        pan_tolerance = grid_size * 1.5
+        valid_for_tilt = []
+        for r in records:
+            if any(abs(r['x_ratio'] - vx) <= pan_tolerance for vx in pan_valid_x):
+                valid_for_tilt.append(r)
+
+        y_buckets: Dict[float, List[float]] = {}
+        for r in valid_for_tilt:
+            y_key = round(r['y_ratio'] / grid_size) * grid_size
+            y_buckets.setdefault(y_key, []).append(r['tilt'])
+
+        self.tilt_lookup = [(y_key, float(np.median(tilts))) for y_key, tilts in sorted(y_buckets.items())]
+
+    def _filter_continuous_monotonic(self, entries: List[Tuple[float, float, float]],
+                                     max_step: float = 60.0) -> List[Tuple[float, float, float]]:
+        n = len(entries)
+        if n <= 2:
+            return [(x, pan % 360, w) for x, pan, w in entries]
+
+        best_result = []
+        for direction in ['decreasing', 'increasing']:
+            dp = [1] * n
+            parent = [-1] * n
+            for i in range(1, n):
+                for j in range(i):
+                    diff = entries[i][1] - entries[j][1]
+                    while diff > 180:
+                        diff -= 360
+                    while diff < -180:
+                        diff += 360
+                    if direction == 'decreasing':
+                        ok = diff <= 0 and abs(diff) <= max_step
+                    else:
+                        ok = diff >= 0 and abs(diff) <= max_step
+                    if ok and dp[j] + 1 > dp[i]:
+                        dp[i] = dp[j] + 1
+                        parent[i] = j
+            end = max(range(n), key=lambda i: dp[i])
+            seq = []
+            idx = end
+            while idx >= 0:
+                seq.append(idx)
+                idx = parent[idx]
+            seq.reverse()
+            result = self._unwrap_entries(entries, seq)
+            if len(result) > len(best_result):
+                best_result = result
+
+        if len(best_result) < 3 and n >= 3:
+            for wider_step in [90, 120, 180]:
+                for direction in ['decreasing', 'increasing']:
+                    # 同上逻辑,用 wider_step
+                    dp = [1] * n
+                    parent = [-1] * n
+                    for i in range(1, n):
+                        for j in range(i):
+                            diff = entries[i][1] - entries[j][1]
+                            while diff > 180:
+                                diff -= 360
+                            while diff < -180:
+                                diff += 360
+                            if direction == 'decreasing':
+                                ok = diff <= 0 and abs(diff) <= wider_step
+                            else:
+                                ok = diff >= 0 and abs(diff) <= wider_step
+                            if ok and dp[j] + 1 > dp[i]:
+                                dp[i] = dp[j] + 1
+                                parent[i] = j
+                    end = max(range(n), key=lambda i: dp[i])
+                    seq = []
+                    idx = end
+                    while idx >= 0:
+                        seq.append(idx)
+                        idx = parent[idx]
+                    seq.reverse()
+                    result = self._unwrap_entries(entries, seq)
+                    if len(result) > len(best_result):
+                        best_result = result
+                if len(best_result) >= 3:
+                    break
+
+        if not best_result:
+            return [(x, pan % 360, w) for x, pan, w in entries]
+        return best_result
+
+    def _unwrap_entries(self, entries: List[Tuple[float, float, float]],
+                        indices: List[int]) -> List[Tuple[float, float, float]]:
+        result = []
+        prev = None
+        for idx in indices:
+            x, pan, w = entries[idx]
+            if prev is None:
+                unwrapped = pan
+            else:
+                diff = pan - prev
+                while diff > 180:
+                    pan -= 360
+                    diff = pan - prev
+                while diff < -180:
+                    pan += 360
+                    diff = pan - prev
+                unwrapped = pan
+            prev = unwrapped
+            result.append((x, unwrapped, w))
+        return result
+
+    def _calculate_rms_error(self, records: List[Dict]) -> float:
+        total = 0.0
+        for r in records:
+            pred_pan, pred_tilt = self.transform(
+                int(r['x_ratio'] * self.panorama_width),
+                int(r['y_ratio'] * self.panorama_height)
+            )
+            pan_err = self._angular_diff(pred_pan, r['pan'])
+            tilt_err = pred_tilt - r['tilt']
+            total += pan_err ** 2 + tilt_err ** 2
+        return math.sqrt(total / len(records))
+
+    def transform(self, panorama_x: int, panorama_y: int) -> Tuple[float, float]:
+        x_ratio = panorama_x / self.panorama_width
+        y_ratio = panorama_y / self.panorama_height
+
+        if self.pan_lookup:
+            pan = self._interp_lookup(self.pan_lookup, x_ratio)
+        else:
+            pan = self.pan_offset + self.pan_scale_x * x_ratio + self.pan_scale_y * y_ratio
+
+        if self.tilt_lookup:
+            tilt = self._interp_lookup(self.tilt_lookup, y_ratio)
+        else:
+            tilt = self.tilt_offset + self.tilt_scale_x * x_ratio + self.tilt_scale_y * y_ratio
+
+        pan = pan % 360
+        tilt = max(-90, min(90, tilt))
+        return pan, tilt
+
+    def _interp_lookup(self, lookup: List[Tuple[float, float]], ratio: float) -> float:
+        if not lookup:
+            return 0.0
+        if len(lookup) == 1:
+            return lookup[0][1]
+        if ratio <= lookup[0][0]:
+            return lookup[0][1]
+        if ratio >= lookup[-1][0]:
+            return lookup[-1][1]
+
+        lo, hi = 0, len(lookup) - 1
+        while lo < hi - 1:
+            mid = (lo + hi) // 2
+            if lookup[mid][0] <= ratio:
+                lo = mid
+            else:
+                hi = mid
+
+        x0, v0 = lookup[lo]
+        x1, v1 = lookup[hi]
+        if abs(x1 - x0) < 1e-10:
+            return v0
+        t = (ratio - x0) / (x1 - x0)
+        return v0 + t * (v1 - v0)
+
+    def to_dict(self) -> Dict:
+        return {
+            'pan_offset': self.pan_offset,
+            'pan_scale_x': self.pan_scale_x,
+            'pan_scale_y': self.pan_scale_y,
+            'tilt_offset': self.tilt_offset,
+            'tilt_scale_x': self.tilt_scale_x,
+            'tilt_scale_y': self.tilt_scale_y,
+            'rms_error': self.rms_error,
+            'pan_lookup': self.pan_lookup,
+            'tilt_lookup': self.tilt_lookup,
+            'overlap_ranges': [{
+                'pan_start': 180.0,
+                'pan_end': 340.0,
+                'tilt_start': -35.0,
+                'tilt_end': 45.0,
+            }],
+            'mount_type': 'wall',
+            'tilt_flip': False,
+            'pan_flip': False,
+            'generated_from': 'ptz_panorama_fused_matching',
+            'note': 'x_ratio,y_ratio 为全景图归一化坐标;transform 输入像素坐标 3840x1080',
+        }
+
+    def save(self, path: str) -> None:
+        import json
+        with open(path, 'w', encoding='utf-8') as f:
+            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

+ 33 - 0
calibration_scan_180_360/test_mapping_model.py

@@ -0,0 +1,33 @@
+import numpy as np
+from mapping_model import MappingModel
+
+
+def test_mapping_model_fits_linear_pan_tilt():
+    # 构造 9 个校准点:x_ratio 与 pan 线性递减,y_ratio 与 tilt 线性递减
+    records = []
+    for i, pan in enumerate([340, 300, 260, 220, 180]):
+        for j, tilt in enumerate([45, 5, -35]):
+            x_ratio = (360 - pan) / 180.0  # 0.111 ~ 1.0
+            y_ratio = (tilt + 35) / 80.0   # 0.0 ~ 1.0
+            records.append({
+                'pan': float(pan),
+                'tilt': float(tilt),
+                'x_ratio': x_ratio,
+                'y_ratio': y_ratio,
+                'confidence': 'high',
+            })
+
+    model = MappingModel()
+    model.fit(records)
+
+    pan, tilt = model.transform(1920, 540)  # 中心点
+    # 根据 x_ratio=(360-pan)/180 的线性关系,中心 x=0.5 对应 pan≈270
+    assert 250 <= pan <= 290
+    assert -10 <= tilt <= 15
+
+    # 验证保存/加载格式
+    data = model.to_dict()
+    assert 'pan_lookup' in data
+    assert 'tilt_lookup' in data
+    assert len(data['pan_lookup']) >= 3
+    assert len(data['tilt_lookup']) >= 3