# test_spatial_scanner.py (5.4 KB)

  1. import sys
  2. import os
  3. import tempfile
  4. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  5. import cv2
  6. import numpy as np
  7. import pytest
  8. from core.spatial_scanner import SpatialScanner
  9. from core.coord_utils import compute_sample_grid
  10. class FakePTZ:
  11. def __init__(self):
  12. self.positions = []
  13. def goto_exact_position(self, pan, tilt, zoom):
  14. self.positions.append((pan, tilt, zoom))
  15. def test_spatial_scanner_grid(tmp_path):
  16. ptz = FakePTZ()
  17. counter = {"n": 0}
  18. def frame_source():
  19. counter["n"] += 1
  20. return np.zeros((120, 160, 3), dtype=np.uint8)
  21. scanner = SpatialScanner("g1", ptz, frame_source, str(tmp_path), stabilize_time=0.0)
  22. result = scanner.run(pan_range=(0, 90), tilt_layers=(-10, 0), pan_step=30, zoom=1)
  23. assert len(result["samples"]) == 6
  24. assert result["panorama_path"] is not None
  25. assert len(ptz.positions) == 6
  26. def test_invalid_pan_step_raises_value_error(tmp_path):
  27. ptz = FakePTZ()
  28. scanner = SpatialScanner("g1", ptz, lambda: None, str(tmp_path), stabilize_time=0.0)
  29. with pytest.raises(ValueError, match="pan_step must be positive"):
  30. scanner.run(pan_range=(0, 90), tilt_layers=(-10, 0), pan_step=0)
  31. def test_compute_sample_grid_validation():
  32. with pytest.raises(ValueError, match="pan_step must be positive"):
  33. compute_sample_grid(pan_step=0)
  34. with pytest.raises(ValueError, match="tilt_layers must not be empty"):
  35. compute_sample_grid(tilt_layers=())
  36. with pytest.raises(ValueError, match="pan_range start must be less than end"):
  37. compute_sample_grid(pan_range=(180.0, 180.0))
  38. def test_cancellation_stops_early(tmp_path):
  39. ptz = FakePTZ()
  40. counter = {"n": 0}
  41. def frame_source():
  42. counter["n"] += 1
  43. if counter["n"] == 2:
  44. scanner.cancel()
  45. return np.zeros((120, 160, 3), dtype=np.uint8)
  46. scanner = SpatialScanner("g1", ptz, frame_source, str(tmp_path), stabilize_time=0.0)
  47. result = scanner.run(pan_range=(0, 60), tilt_layers=(-10, 0), pan_step=30, zoom=1)
  48. assert len(result["samples"]) < 4
  49. assert scanner.progress["state"] == "cancelled"
  50. def test_empty_sample_list_returns_no_panorama(tmp_path):
  51. ptz = FakePTZ()
  52. scanner = SpatialScanner("g1", ptz, lambda: None, str(tmp_path), stabilize_time=0.0)
  53. scanner._wait_frame = lambda timeout: None
  54. result = scanner.run(pan_range=(0, 60), tilt_layers=(-10, 0), pan_step=30, zoom=1)
  55. assert result["samples"] == []
  56. assert result["panorama_path"] is None
  57. assert scanner.progress["current"] == 0
  58. def test_progress_callback_invoked(tmp_path):
  59. ptz = FakePTZ()
  60. progress_snapshots = []
  61. def frame_source():
  62. return np.zeros((120, 160, 3), dtype=np.uint8)
  63. def progress_callback(progress):
  64. progress_snapshots.append(dict(progress))
  65. scanner = SpatialScanner("g1", ptz, frame_source, str(tmp_path), stabilize_time=0.0)
  66. result = scanner.run(
  67. pan_range=(0, 60),
  68. tilt_layers=(-10, 0),
  69. pan_step=30,
  70. zoom=1,
  71. progress_callback=progress_callback,
  72. )
  73. expected_samples = len(result["samples"])
  74. assert expected_samples > 0
  75. # 扫描开始前会报告一次初始进度,之后每成功采集一个点报告一次
  76. assert len(progress_snapshots) == expected_samples + 1
  77. assert progress_snapshots[0]["current"] == 0
  78. assert progress_snapshots[0]["state"] == "scanning"
  79. assert progress_snapshots[-1]["current"] == expected_samples
  80. assert progress_snapshots[-1]["state"] == "scanning"
  81. def test_prerun_cancellation_returns_empty_result(tmp_path):
  82. ptz = FakePTZ()
  83. scanner = SpatialScanner("g1", ptz, lambda: None, str(tmp_path), stabilize_time=0.0)
  84. scanner.cancel()
  85. result = scanner.run(pan_range=(0, 60), tilt_layers=(-10, 0), pan_step=30, zoom=1)
  86. assert result["samples"] == []
  87. assert result["panorama_path"] is None
  88. assert scanner.progress["state"] == "cancelled"
  89. def test_panorama_does_not_blend_overlapping_samples(tmp_path):
  90. """重叠区域应直接覆盖,而不是加权融合导致虚化。"""
  91. ptz = FakePTZ()
  92. calls = {"n": 0}
  93. def frame_source():
  94. calls["n"] += 1
  95. # 第一张红色,第二张蓝色,两者水平视场约 55°,在 0°/30° 处明显重叠
  96. color = (0, 0, 255) if calls["n"] == 1 else (255, 0, 0)
  97. return np.full((100, 100, 3), color, dtype=np.uint8)
  98. scanner = SpatialScanner("g1", ptz, frame_source, str(tmp_path), stabilize_time=0.0)
  99. result = scanner.run(pan_range=(0, 60), tilt_layers=(0,), pan_step=30, zoom=1)
  100. assert result["panorama_path"] is not None
  101. panorama = cv2.imread(result["panorama_path"])
  102. assert panorama is not None
  103. # 在第二张图(pan=30°)的中心区域采样;直接覆盖应接近蓝色
  104. height, width = panorama.shape[:2]
  105. u = int((30 / 60) * width)
  106. v = int(((90 - 0) / 180) * height)
  107. roi = panorama[
  108. max(0, v - 50):min(height, v + 50),
  109. max(0, u - 50):min(width, u + 50),
  110. ]
  111. mask = roi.sum(axis=2) > 0
  112. assert mask.sum() > 0
  113. mean_color = roi[mask].mean(axis=0)
  114. blue = np.array([255, 0, 0], dtype=np.float32)
  115. blend = np.array([127, 0, 127], dtype=np.float32)
  116. dist_to_blue = np.linalg.norm(mean_color - blue)
  117. dist_to_blend = np.linalg.norm(mean_color - blend)
  118. # 均值应更接近蓝色,而不是红蓝融合色
  119. assert dist_to_blue < 60
  120. assert dist_to_blend > 90